Unverified commit 62635f0f authored by msbaines, committed by GitHub

[feat] multiprocess_pipe: add support for testing gpu-gpu rpc (#552)

parent 9a6ca9bd
@@ -22,7 +22,10 @@ if TYPE_CHECKING:
 else:
     Module = nn.Module
 
-BOUNCE_TENSORS = True
+if torch.__version__.split("+")[0].split(".")[:3] <= ["1", "8", "1"]:
+    BOUNCE_TENSORS = True
+else:
+    BOUNCE_TENSORS = False
 
 
 def _verify_module(module: List[LayerSpec]) -> None:
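For reference, the new gate keys BOUNCE_TENSORS off the installed torch version by comparing the split version components as strings. A minimal standalone sketch with illustrative inputs (the `needs_bounce` helper is ours, not part of the commit):

    # Sketch of the version gate above: strip any local build suffix ("+cu111"),
    # split into components, and compare against ["1", "8", "1"].
    def needs_bounce(version: str) -> bool:
        return version.split("+")[0].split(".")[:3] <= ["1", "8", "1"]

    print(needs_bounce("1.8.1+cu111"))  # True  -> bounce tensors through CPU
    print(needs_bounce("1.9.0"))        # False -> rely on direct GPU-GPU RPC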
@@ -53,7 +56,7 @@ class _ToHere(Module):
         if BOUNCE_TENSORS:
             return x_rref.remote().cpu().to_here().to(self.device)
         else:
-            return x_rref.to_here().to(self.device)
+            return x_rref.to_here()
 
 
 def _create_sequential(layer_spec: List[LayerSpec], device: str) -> Module:
@@ -67,7 +70,7 @@ def _rcat(tensors: List) -> Tensor:
 
 
 def _parameter_rrefs(module: rpc.RRef) -> List[rpc.RRef]:
-    return [rpc.RRef(p) for p in module.to_here().parameters()]
+    return [rpc.RRef(p) for p in module.local_value().parameters()]
 
 
 def rloss(loss_func: Callable, input_rref: rpc.RRef, target_rref: rpc.RRef) -> rpc.RRef:
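The change above stops copying the module with `to_here()` and instead reads it with `local_value()`, which is valid because `_parameter_rrefs` runs on the worker that owns the module RRef. A single-process sketch of the difference, assuming a local TCP rendezvous (the port is illustrative):

    import torch
    import torch.distributed.rpc as rpc

    # One-worker RPC setup just so we can construct RRefs (port chosen arbitrarily).
    rpc.init_rpc(
        "worker0",
        rank=0,
        world_size=1,
        rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
            init_method="tcp://localhost:29501"
        ),
    )

    module_rref = rpc.RRef(torch.nn.Linear(4, 4))
    # local_value() returns the owner's object directly (no copy); to_here()
    # would serialize and copy it back to the caller, which is wasteful here.
    params = [rpc.RRef(p) for p in module_rref.local_value().parameters()]
    print(len(params))  # 2 (weight and bias)

    rpc.shutdown()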
...
@@ -5,6 +5,7 @@ from torch.futures import Future
 
 class RRef:
     def __init__(self, t: Any) -> None: ...
+    def local_value(self) -> Any: ...
     def owner(self) -> WorkerInfo: ...
     def remote(self) -> Any: ...
     def rpc_sync(self) -> Any: ...
...
@@ -21,7 +21,10 @@ import torch.nn as nn
 from fairscale.experimental.nn.multiprocess_pipe import DistributedLoss, MultiProcessPipe
 from fairscale.utils.testing import torch_version
 
-BOUNCE_TENSORS = True
+if torch_version() <= (1, 8, 1):
+    BOUNCE_TENSORS = True
+else:
+    BOUNCE_TENSORS = False
 
 CPU_DEVICES = ["worker0/cpu", "worker1/cpu"]
 GPU_DEVICES = ["worker0/cuda:0", "worker1/cuda:1"]
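In the test file the same gate is written with `torch_version()`, which (as its use here implies) returns the version as a tuple of ints, so the comparison is element-wise on integers. A small illustration with a hypothetical `parse_version` helper:

    # Tuple comparison mirrors `torch_version() <= (1, 8, 1)` above.
    def parse_version(version: str) -> tuple:
        return tuple(int(p) for p in version.split("+")[0].split(".")[:3])

    assert parse_version("1.8.1+cu111") <= (1, 8, 1)    # bounce through CPU
    assert not (parse_version("1.9.0") <= (1, 8, 1))    # direct GPU-GPU RPC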
@@ -46,6 +49,10 @@ def rpc_worker(rank, world_size, init_file, func, *args):
         )
     else:
         options = rpc.TensorPipeRpcBackendOptions(init_method="file://" + init_file)
+    if torch_version() > (1, 8, 1):
+        for i in range(world_size):
+            if i != rank:
+                options.set_device_map("worker" + str(i), {rank: i})
     rpc.init_rpc(
         "worker" + str(rank),
         rank=rank,
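The `set_device_map` calls added above are what enable direct GPU-GPU RPC: each worker maps its own `cuda:<rank>` onto the peer's `cuda:<i>`, so CUDA tensors in RPC arguments and replies travel without a CPU bounce. A hedged two-process sketch (worker names, port, and the `echo` helper are illustrative; it needs at least two GPUs):

    import torch
    import torch.distributed.rpc as rpc
    import torch.multiprocessing as mp

    def echo(t):
        # Runs on the callee; with a device map the argument arrives on the callee's GPU.
        return t * 2

    def run(rank, world_size):
        options = rpc.TensorPipeRpcBackendOptions(init_method="tcp://localhost:29502")
        for i in range(world_size):
            if i != rank:
                # Map this worker's cuda:<rank> to worker i's cuda:<i>, as in the test above.
                options.set_device_map("worker" + str(i), {rank: i})
        rpc.init_rpc("worker" + str(rank), rank=rank, world_size=world_size,
                     rpc_backend_options=options)
        if rank == 0:
            x = torch.ones(2, 2, device="cuda:0")
            y = rpc.rpc_sync("worker1", echo, args=(x,))
            print(y.device)  # expected: cuda:0, the reply is mapped back to the caller
        rpc.shutdown()

    if __name__ == "__main__":
        mp.spawn(run, args=(2,), nprocs=2)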
@@ -124,9 +131,10 @@ def forward_chunks(devices):
 @rpc_test(world_size=2)
 @pytest.mark.parametrize("devices", DEVICES)
 def forward_multi(devices):
+    device = devices[0].split("/")[1]
     torch.random.manual_seed(3)
     torch.cuda.manual_seed_all(3)
-    x = torch.randn(8, 4)
+    x = torch.randn(8, 4).to(device)
     model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
     pipe = MultiProcessPipe(model, balance=[1, 1], chunks=4, devices=devices[:2])
     if BOUNCE_TENSORS:
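The per-test changes here and in the hunks below all have the same shape: parse the plain device out of the `worker<rank>/<device>` string and create the input batch there, so it already lives on the first pipeline stage's device. A minimal sketch (device lists copied from the test constants; falls back to CPU when no GPU is present):

    import torch

    devices = (["worker0/cuda:0", "worker1/cuda:1"]
               if torch.cuda.is_available()
               else ["worker0/cpu", "worker1/cpu"])
    device = devices[0].split("/")[1]   # "cuda:0" or "cpu"
    x = torch.randn(8, 4).to(device)    # input placed on the first stage's device
    print(device, x.device)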
@@ -142,9 +150,10 @@ def forward_multi(devices):
 @rpc_test(world_size=2)
 @pytest.mark.parametrize("devices", DEVICES)
 def backward(devices):
+    device = devices[0].split("/")[1]
     torch.random.manual_seed(3)
     criterion = DistributedLoss(torch.nn.MSELoss)
-    x = torch.randn(8, 4)
+    x = torch.randn(8, 4).to(device)
     model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
     pipe = MultiProcessPipe(model, balance=[1, 1], chunks=4, devices=devices[:2])
     with dist_autograd.context() as context_id:
@@ -158,9 +167,10 @@ def backward(devices):
 @rpc_test(world_size=2)
 @pytest.mark.parametrize("devices", DEVICES)
 def update(devices):
+    device = devices[0].split("/")[1]
     torch.random.manual_seed(3)
     criterion = DistributedLoss(torch.nn.MSELoss)
-    x = torch.randn(8, 4)
+    x = torch.randn(8, 4).to(device)
     model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
     pipe = MultiProcessPipe(model, balance=[1, 1], chunks=4, devices=devices[:2])
     params = pipe.parameter_rrefs()
...