Unverified commit 204392e5 authored by msbaines, committed by GitHub

[refactor] multiprocess_pipe: only support torch >= 1.9.0 (#561)

parent 34384e1b
fairscale/experimental/nn/multiprocess_pipe.py
@@ -23,11 +23,6 @@ if TYPE_CHECKING:
 else:
     Module = nn.Module
 
-if torch.__version__.split("+")[0].split(".")[:3] <= ["1", "8", "1"]:
-    BOUNCE_TENSORS = True
-else:
-    BOUNCE_TENSORS = False
-
 def _verify_module(module: List[LayerSpec]) -> None:
     if not isinstance(module, List):
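Note on what this hunk deletes: BOUNCE_TENSORS gated a workaround for torch <= 1.8.1, where a CUDA tensor behind an RRef had to be staged through CPU (x_rref.remote().cpu().to_here().to(device)) because RPC could not copy tensors across devices directly. With the device maps available in torch >= 1.9, to_here() delivers the tensor straight to the mapped local device. A minimal single-process sketch of the surviving pattern (the address, port, and use of torch.relu are illustrative, not from this commit):

import os
import torch
import torch.distributed.rpc as rpc

# Single-process RPC setup, just to make the sketch self-contained.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
rpc.init_rpc("worker0", rank=0, world_size=1)

x = torch.tensor([1.0, -1.0])
rref = rpc.remote("worker0", torch.relu, args=(x,))  # RRef to a remote result
y = rref.to_here()  # torch >= 1.9: fetch directly, no CPU bounce
# torch <= 1.8.1 instead needed: rref.remote().cpu().to_here().to(device)
rpc.shutdown()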
@@ -54,9 +49,6 @@ class _ToHere(Module):
         self.device = device
 
     def forward(self, x_rref: rpc.RRef) -> Tensor:  # type: ignore
-        if BOUNCE_TENSORS:
-            return x_rref.remote().cpu().to_here().to(self.device)
-        else:
-            return x_rref.to_here()
+        return x_rref.to_here()
@@ -80,10 +72,7 @@ def _parameter_rrefs(module: rpc.RRef) -> List[rpc.RRef]:
     return [rpc.RRef(p) for p in module.local_value().parameters()]
 
-def rloss(loss_func: Callable, input_rref: rpc.RRef, target_rref: rpc.RRef) -> rpc.RRef:
-    if BOUNCE_TENSORS:
-        return loss_func(input_rref.remote().cpu().to_here(), target_rref.remote().cpu().to_here())
-    else:
-        return loss_func(input_rref.to_here(), target_rref.to_here())
+def _rloss(loss_func: Callable, input_rref: rpc.RRef, target_rref: rpc.RRef) -> rpc.RRef:
+    return loss_func(input_rref.to_here(), target_rref.to_here())
@@ -91,7 +80,7 @@ def DistributedLoss(loss: nn.Module, *args: Tuple, **kwargs: Dict) -> Callable:
     loss_func = loss(*args, **kwargs)
 
     def dloss(input_rref: rpc.RRef, target_rref: rpc.RRef) -> rpc.RRef:
-        return rpc.remote(input_rref.owner(), rloss, args=(loss_func, input_rref, target_rref))
+        return rpc.remote(input_rref.owner(), _rloss, args=(loss_func, input_rref, target_rref))
 
     return dloss
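A hedged usage sketch of the renamed helper, assuming an initialized RPC group: DistributedLoss(nn.MSELoss) returns a dloss callable that takes RRefs and schedules _rloss on whichever worker owns the input, so the loss is computed next to the pipeline output rather than locally:

import torch.nn as nn
from fairscale.experimental.nn.multiprocess_pipe import DistributedLoss

loss_func = DistributedLoss(nn.MSELoss)
# output_rref = pipe(x)                            # RRef from the pipeline
# loss_rref = loss_func(output_rref, target_rref)  # runs on output_rref.owner()
# loss = loss_rref.to_here()                       # fetch the scalar locally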
@@ -164,6 +153,8 @@ class MultiProcessPipe(Module):
     ) -> None:
         super().__init__()
 
+        if torch.__version__.split(".")[:2] < ["1", "9"]:
+            raise RuntimeError("MultiProcessPipe requires torch >= 1.9.0")
         if type(chunks) is not int or chunks <= 0:
             raise ValueError("number of chunks must be positive integer")
         if checkpoint not in ["always", "except_last", "never"]:
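One caveat worth flagging on the new guard: the components are compared as strings, and string comparison is lexicographic, so ["1", "10"] < ["1", "9"] is True in Python and torch 1.10 would be rejected as well. A sketch of a numeric comparison that avoids this (the helper name is hypothetical, not part of the commit):

import torch

def _torch_at_least(major: int, minor: int) -> bool:
    # "1.9.0+cu111" -> (1, 9); drop any local-version suffix first.
    parts = torch.__version__.split("+")[0].split(".")[:2]
    return tuple(int(p) for p in parts) >= (major, minor)

if not _torch_at_least(1, 9):
    raise RuntimeError("MultiProcessPipe requires torch >= 1.9.0")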
tests/experimental/nn/test_multiprocess_pipe.py
@@ -21,11 +21,6 @@ import torch.nn as nn
 from fairscale.experimental.nn.multiprocess_pipe import DistributedLoss, MultiProcessPipe
 from fairscale.utils.testing import torch_version
 
-if torch_version() <= (1, 8, 1):
-    BOUNCE_TENSORS = True
-else:
-    BOUNCE_TENSORS = False
-
 CPU_DEVICES = ["worker0/cpu", "worker1/cpu"]
 GPU_DEVICES = ["worker0/cuda:0", "worker1/cuda:1"]
 if torch.cuda.is_available():
@@ -33,25 +28,12 @@ if torch.cuda.is_available():
 else:
     DEVICES = [CPU_DEVICES]
 
-pytestmark = pytest.mark.skipif(torch_version() < (1, 8, 0), reason="requires torch version >= 1.8.0")
+pytestmark = pytest.mark.skipif(torch_version() < (1, 9, 0), reason="requires torch version >= 1.9.0")
 
 
 def rpc_worker(rank, world_size, init_file, func, *args):
-    if torch_version() == (1, 8, 0):
-        if torch.cuda.is_available():
-            # Workaround for https://github.com/pytorch/pytorch/issues/53844
-            options = rpc.TensorPipeRpcBackendOptions(init_method="file://" + init_file, _transports=["ibv", "uv"])
-        else:
-            # Workaround for https://github.com/pytorch/pytorch/issues/54266
-            options = rpc.TensorPipeRpcBackendOptions(
-                init_method="file://" + init_file,
-                _channels=["mpt_uv", "basic", "cuda_ipc", "cuda_gdr", "cuda_xth", "cuda_basic"],
-            )
-    else:
-        options = rpc.TensorPipeRpcBackendOptions(init_method="file://" + init_file)
-    if torch_version() > (1, 8, 1):
-        for i in range(world_size):
-            if i != rank:
-                options.set_device_map("worker" + str(i), {rank: i})
+    options = rpc.TensorPipeRpcBackendOptions(init_method="file://" + init_file)
+    for i in range(world_size):
+        options.set_device_map("worker" + str(i), {rank: i})
     rpc.init_rpc(
         "worker" + str(rank),
@@ -109,8 +91,9 @@ def parameter_rrefs(devices):
 @rpc_test(world_size=1)
 @pytest.mark.parametrize("devices", DEVICES)
 def forward(devices):
+    device = devices[0].split("/")[1]
     yh = torch.tensor([1.0, 0.0])
-    x = torch.tensor([1.0, -1.0])
+    x = torch.tensor([1.0, -1.0]).to(device)
     model = [("relu", nn.ReLU, (), {})]
     pipe = MultiProcessPipe(model, balance=[1], chunks=1, devices=devices[:1])
     y = pipe(x).to_here().cpu()
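For readers skimming the test: each model entry is a layer spec tuple (name, module_class, args, kwargs), devices take the "worker/device" form defined above, and the pipeline's forward returns an RRef rather than a tensor. A minimal sketch mirroring forward, assuming RPC is initialized as in rpc_worker:

import torch
import torch.nn as nn
from fairscale.experimental.nn.multiprocess_pipe import MultiProcessPipe

model = [("relu", nn.ReLU, (), {})]  # one stage: a single ReLU
pipe = MultiProcessPipe(model, balance=[1], chunks=1, devices=["worker0/cpu"])
y = pipe(torch.tensor([1.0, -1.0])).to_here()  # RRef -> local tensor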
@@ -120,8 +103,9 @@ def forward(devices):
 @rpc_test(world_size=1)
 @pytest.mark.parametrize("devices", DEVICES)
 def forward_chunks(devices):
+    device = devices[0].split("/")[1]
     yh = torch.tensor([1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0])
-    x = torch.tensor([1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0])
+    x = torch.tensor([1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0]).to(device)
     model = [("relu", nn.ReLU, (), {})]
     pipe = MultiProcessPipe(model, balance=[1], chunks=4, devices=devices[:1])
     y = pipe(x).to_here().cpu()
@@ -139,9 +123,6 @@ def forward_multi(devices, checkpoint):
     x.requires_grad = True  # TODO(msb) remove this limitation
     model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
     pipe = MultiProcessPipe(model, balance=[1, 1], chunks=4, devices=devices[:2], checkpoint=checkpoint)
-    if BOUNCE_TENSORS:
-        y = pipe(x).remote().cpu().to_here()
-    else:
-        y = pipe(x).to_here()
+    y = pipe(x).to_here()
     expected_sum = torch.tensor(5.0615)
     assert y.shape == torch.Size([8, 4])