"vscode:/vscode.git/clone" did not exist on "c07e9601164aac5cc9aa860b36d8cf0562e00982"
Unverified commit cae9b638, authored by msbaines, committed by GitHub

[refactor] pipe: separate out Single and MultiProcess pipe (#326)

parent eab1551a
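
In short: the multiprocess implementation of Pipe becomes its own class, MultiProcessPipe (implemented in fairscale.nn.pipe.multiprocess_pipe); the MultiProcess/AsyncSchedule style constants move with it, and the partition list is exposed as partitions instead of mp_partitions. A minimal before/after sketch of an affected call site (values are illustrative; it assumes the process group and worker map are already set up, as the tests below do via fairscale.utils.testing):

    import torch.nn as nn
    from fairscale.utils.testing import get_worker_map

    model = nn.Sequential(nn.Linear(1, 1))

    # before this commit
    from fairscale.nn.pipe import Pipe
    pipe = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=1)

    # after this commit
    from fairscale.nn.pipe import MultiProcessPipe
    pipe = MultiProcessPipe(model, balance=[1], style=MultiProcessPipe.MultiProcess, worker_map=get_worker_map(), chunks=1)
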
@@ -31,15 +31,15 @@ from fairscale.nn.model_parallel.initialize import (
     get_pipeline_parallel_group,
     initialize_model_parallel,
 )
-from fairscale.nn.pipe import LazyModule, Pipe
+from fairscale.nn.pipe import LazyModule, MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn, torch_version


 @torch_spawn([2])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def parameters(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
-    pipe = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1)
+    pipe = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1)
     if torch.distributed.get_rank() == 0:
         assert list(pipe.parameters()) != []
     else:
@@ -107,7 +107,7 @@ def mpi():

 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def public_attrs(pipeline_style):
     class MyString:
         def __init__(self, value):
@@ -118,7 +118,7 @@ def public_attrs(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
-    pipe = Pipe(
+    pipe = MultiProcessPipe(
         model,
         balance=(1,),
         style=pipeline_style,
@@ -127,9 +127,7 @@ def public_attrs(pipeline_style):
         checkpoint=MyString("always"),
     )

-    print(f"balance = {pipe.devices}")
     assert pipe.balance == [1]
-    assert pipe.devices is None
     assert pipe.chunks == 42
     assert isinstance(pipe.chunks, int)
     assert pipe.checkpoint == "always"
@@ -138,13 +136,13 @@ def public_attrs(pipeline_style):

 @torch_spawn([2])
 @pytest.mark.parametrize("balance", [[2], [1, 1]])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def sequential_like(balance, pipeline_style):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)

     model = nn.Sequential(a, b)
-    model = Pipe(model, balance, style=pipeline_style, worker_map=get_worker_map())
+    model = MultiProcessPipe(model, balance, style=pipeline_style, worker_map=get_worker_map())

     if balance == [2]:
         if torch.distributed.get_rank() == 0:
@@ -177,7 +175,7 @@ def sequential_like(balance, pipeline_style):

 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def balance_wrong_length(pipeline_style):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)
@@ -185,14 +183,14 @@ def balance_wrong_length(pipeline_style):
     model = nn.Sequential(a, b)

     with pytest.raises(ValueError):
-        Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
+        MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())

     with pytest.raises(ValueError):
-        Pipe(model, balance=[3], style=pipeline_style, worker_map=get_worker_map())
+        MultiProcessPipe(model, balance=[3], style=pipeline_style, worker_map=get_worker_map())


 @torch_spawn([2])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def balance_less_than_1(pipeline_style):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)
@@ -200,39 +198,39 @@ def balance_less_than_1(pipeline_style):
     model = nn.Sequential(a, b)

     with pytest.raises(ValueError):
-        Pipe(model, balance=[0, 2], style=pipeline_style, worker_map=get_worker_map())
+        MultiProcessPipe(model, balance=[0, 2], style=pipeline_style, worker_map=get_worker_map())

     with pytest.raises(ValueError):
-        Pipe(model, balance=[-1, 3], style=pipeline_style, worker_map=get_worker_map())
+        MultiProcessPipe(model, balance=[-1, 3], style=pipeline_style, worker_map=get_worker_map())


 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def chunks_less_than_1(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))

     with pytest.raises(ValueError):
-        Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=0)
+        MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=0)

     with pytest.raises(ValueError):
-        Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=-1)
+        MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=-1)


 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def too_few_devices(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1))

     with pytest.raises(IndexError):
         # len(balance) > len(group.size())
-        model = Pipe(model, balance=[1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map())
+        model = MultiProcessPipe(model, balance=[1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map())


 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def batch_size_indivisible(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
-    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
+    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4)

     with pytest.warns(None) as record:
         model(torch.rand(7, 1))
@@ -242,10 +240,10 @@ def batch_size_indivisible(pipeline_style):

 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def batch_size_small(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
-    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
+    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4)

     with pytest.warns(None) as record:
         model(torch.rand(2, 1))
@@ -255,7 +253,7 @@ def batch_size_small(pipeline_style):

 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def checkpoint_mode(pipeline_style):
     def count_grad_fn(grad_fn, name, visited=set()):
         if grad_fn in visited:
@@ -275,7 +273,7 @@ def checkpoint_mode(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
     input = torch.rand(2, 1)

-    always = Pipe(
+    always = MultiProcessPipe(
         model,
         balance=[1],
         style=pipeline_style,
@@ -284,7 +282,7 @@ def checkpoint_mode(pipeline_style):
         checkpoint="always",
         pipelined_backward=False,
     )
-    except_last = Pipe(
+    except_last = MultiProcessPipe(
         model,
         balance=[1],
         style=pipeline_style,
@@ -293,7 +291,7 @@ def checkpoint_mode(pipeline_style):
         checkpoint="except_last",
         pipelined_backward=False,
     )
-    never = Pipe(
+    never = MultiProcessPipe(
         model,
         balance=[1],
         style=pipeline_style,
@@ -313,12 +311,12 @@ def checkpoint_mode(pipeline_style):

 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def checkpoint_mode_invalid(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))

     with pytest.raises(ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'"):
-        Pipe(
+        MultiProcessPipe(
             model,
             balance=[1],
             style=pipeline_style,
@@ -329,23 +327,27 @@ def checkpoint_mode_invalid(pipeline_style):

 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def checkpoint_mode_when_chunks_1(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))

     # All checkpoint modes are fine.
-    Pipe(
+    MultiProcessPipe(
         model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="except_last",
     )
-    Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="always")
-    Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="never")
+    MultiProcessPipe(
+        model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="always"
+    )
+    MultiProcessPipe(
+        model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="never"
+    )


 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def checkpoint_eval(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
-    model = Pipe(
+    model = MultiProcessPipe(
         model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
     )
     input = torch.rand(2, 1)
@@ -373,7 +375,7 @@ def checkpoint_eval(pipeline_style):

 @torch_spawn([2])
 @pytest.mark.xfail(torch_version() < (1, 6, 0), reason="Doesn't work on torch < 1.6.0", strict=True)
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def checkpoint_non_float_input(pipeline_style):
     class ForkNonFloat(nn.Module):
         def forward(self, input):
@@ -384,7 +386,7 @@ def checkpoint_non_float_input(pipeline_style):
             return input[0] * 2

     model = nn.Sequential(ForkNonFloat(), JoinNonFloat())
-    model = Pipe(
+    model = MultiProcessPipe(
         model,
         balance=[1, 1],
         style=pipeline_style,
@@ -399,17 +401,17 @@ def checkpoint_non_float_input(pipeline_style):
     if model.group.rank() == 1:
         # with torch.autograd.detect_anomaly():
         output.backward()
-    elif pipeline_style == Pipe.MultiProcess:
+    elif pipeline_style == MultiProcessPipe.MultiProcess:
         model.back_helper(output)

     torch.distributed.barrier()


 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def no_grad(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
-    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2)
+    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2)
     input = torch.rand(2, 1)

     latent = None
@@ -421,7 +423,7 @@ def no_grad(pipeline_style):
         nonlocal latent
         latent = output

-    partition = model.mp_partitions[0]
+    partition = model.partitions[0]
     partition.module.register_forward_hook(hook)

     with torch.no_grad():
@@ -431,7 +433,7 @@ def no_grad(pipeline_style):

 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def exception(pipeline_style):
     class ExpectedException(Exception):
         pass
@@ -441,7 +443,7 @@ def exception(pipeline_style):
             raise ExpectedException()

     model = nn.Sequential(Raise())
-    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1)
+    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1)

     with pytest.raises(ExpectedException):
         model(torch.rand(1))
@@ -451,7 +453,7 @@ def exception(pipeline_style):

 @torch_spawn([4])
 @pytest.mark.skipif(torch.cuda.is_available() and torch.cuda.device_count() < 4, reason="Not enough GPUs")
 @pytest.mark.xfail(strict=True)
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def exception_early_stop_asap(pipeline_style):
     """Even if the first partitions have finished processing, the partition
     before the failed partition should be killed as soon as possible.
@@ -480,7 +482,7 @@ def exception_early_stop_asap(pipeline_style):
             raise ExpectedException()

     model = nn.Sequential(Pass(), Pass(), Counter(), Raise())
-    model = Pipe(model, [1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
+    model = MultiProcessPipe(model, [1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)

     with pytest.raises(ExpectedException):
         model(torch.rand(3))
@@ -490,7 +492,7 @@ def exception_early_stop_asap(pipeline_style):

 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def input_pair(pipeline_style):
     class Two(nn.Module):
         def __init__(self):
@@ -503,7 +505,7 @@ def input_pair(pipeline_style):
             return (self.fc_a(a), self.fc_b(b))

     model = nn.Sequential(Two())
-    model = Pipe(
+    model = MultiProcessPipe(
         model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
     )
@@ -519,7 +521,7 @@ def input_pair(pipeline_style):

 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def input_singleton(pipeline_style):
     class One(nn.Module):
         def __init__(self):
@@ -531,7 +533,7 @@ def input_singleton(pipeline_style):
             return (self.fc(a),)

     model = nn.Sequential(One())
-    model = Pipe(
+    model = MultiProcessPipe(
         model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
     )
@@ -546,10 +548,10 @@ def input_singleton(pipeline_style):

 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def input_varargs(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
-    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
+    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())

     a = torch.rand(1)
     b = torch.rand(1)
@@ -560,14 +562,14 @@ def input_varargs(pipeline_style):

 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def non_tensor(pipeline_style):
     class NonTensor(nn.Module):
         def forward(self, _):
             return "hello"

     model = nn.Sequential(NonTensor())
-    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
+    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
     x = torch.rand(1)

     # TypeError: expected Tensor as element 0 in argument 0, but got str
@@ -580,14 +582,14 @@ def non_tensor(pipeline_style):

 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def non_tensor_tuple(pipeline_style):
     class NonTensorTuple(nn.Module):
         def forward(self, x):
             return (x, "hello")

     model = nn.Sequential(NonTensorTuple())
-    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
+    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
     x = torch.rand(1)

     # TypeError: CheckpointBackward.forward: expected Variable (got str) for return value 1
@@ -602,7 +604,7 @@ def non_tensor_tuple(pipeline_style):

 @torch_spawn([1])
 @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
 @pytest.mark.parametrize("lazy", [True, False])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def deferred_batch_norm(checkpoint, lazy, pipeline_style):
     bn = nn.BatchNorm2d(3)
     pipe_bn = deepcopy(bn)
@@ -611,7 +613,7 @@ def deferred_batch_norm(checkpoint, lazy, pipeline_style):
         model = [LazyModule(pipe_fn)]
     else:
         model = nn.Sequential(pipe_bn)
-    pipe = Pipe(
+    pipe = MultiProcessPipe(
         model,
         balance=[1],
         style=pipeline_style,
@@ -632,7 +634,7 @@ def deferred_batch_norm(checkpoint, lazy, pipeline_style):

 @torch_spawn([1])
 @pytest.mark.parametrize("checkpoint", ["never", "always"])
 @pytest.mark.parametrize("lazy", [True, False])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def deferred_batch_norm_params(checkpoint, lazy, pipeline_style):
     bn = nn.BatchNorm2d(3)
     pipe_bn = deepcopy(bn)
@@ -641,7 +643,7 @@ def deferred_batch_norm_params(checkpoint, lazy, pipeline_style):
         model = [LazyModule(pipe_fn)]
     else:
         model = nn.Sequential(pipe_bn)
-    pipe = Pipe(
+    pipe = MultiProcessPipe(
         model,
         balance=[1],
         style=pipeline_style,
@@ -663,7 +665,7 @@ def deferred_batch_norm_params(checkpoint, lazy, pipeline_style):

 @torch_spawn([4])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def devices(pipeline_style):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)
@@ -671,7 +673,7 @@ def devices(pipeline_style):
     # There are two extra ranks.
     model = nn.Sequential(a, b, c)
-    model = Pipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map())
+    model = MultiProcessPipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map())

     # Extra devices must be discarded.
     if model.group.rank() == 3:
@@ -679,17 +681,17 @@ def devices(pipeline_style):

 @torch_spawn([2])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def partitions(pipeline_style):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)

     model = nn.Sequential(a, b)
-    model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
+    model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())

-    assert isinstance(model.mp_partitions, list)
+    assert isinstance(model.partitions, list)
     assert len(model) == 1
-    assert isinstance(model.mp_partitions[0].module, nn.Sequential)
+    assert isinstance(model.partitions[0].module, nn.Sequential)

     if model.group.rank() == 0:
         assert "0.0.weight" in model.state_dict()
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) @pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def deny_moving(pipeline_style): def deny_moving(pipeline_style):
a = nn.Linear(1, 1) a = nn.Linear(1, 1)
b = nn.Linear(1, 1) b = nn.Linear(1, 1)
model = nn.Sequential(a, b) model = nn.Sequential(a, b)
model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map()) model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
model.cuda() model.cuda()
model.cpu() model.cpu()
...@@ -723,29 +725,29 @@ def deny_moving(pipeline_style): ...@@ -723,29 +725,29 @@ def deny_moving(pipeline_style):
@torch_spawn([1]) @torch_spawn([1])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) @pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def empty_module(pipeline_style): def empty_module(pipeline_style):
# Empty sequential module is not illegal. # Empty sequential module is not illegal.
model = nn.Sequential() model = nn.Sequential()
model = Pipe(model, [], style=pipeline_style, worker_map=get_worker_map()) model = MultiProcessPipe(model, [], style=pipeline_style, worker_map=get_worker_map())
assert model(torch.tensor([42])) == torch.tensor([42]) assert model(torch.tensor([42])) == torch.tensor([42])
assert model((torch.tensor([42]),)) == (torch.tensor([42]),) assert model((torch.tensor([42]),)) == (torch.tensor([42]),)
# But only tensor or tensors is legal in Pipe. # But only tensor or tensors is legal in MultiProcessPipe.
with pytest.raises(TypeError): with pytest.raises(TypeError):
model(42) model(42)
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) @pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def named_children(pipeline_style): def named_children(pipeline_style):
a = nn.Linear(1, 1) a = nn.Linear(1, 1)
b = nn.Linear(1, 1) b = nn.Linear(1, 1)
model = nn.Sequential(OrderedDict([("a", a), ("b", b)])) model = nn.Sequential(OrderedDict([("a", a), ("b", b)]))
model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map()) model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
names = set(n for n, _ in model.named_modules()) names = set(n for n, _ in model.named_modules())
if model.group.rank() == 0: if model.group.rank() == 0:
@@ -753,30 +755,30 @@ def named_children(pipeline_style):
     else:
         assert "0.b" in names

-    # Pipe doesn't support __getattr__. Unlike nn.Sequential, Pipe requires
+    # MultiProcessPipe doesn't support __getattr__. Unlike nn.Sequential, MultiProcessPipe requires
     # several methods in its namespace.
     with pytest.raises(AttributeError):
         model.a


 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def recommend_auto_balance(pipeline_style):
     with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
         # balance is required
-        Pipe(nn.Sequential())
+        MultiProcessPipe(nn.Sequential())

     with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
         # module and sum of balance have different length (module: 0, sum of balance: 1)
-        Pipe(nn.Sequential(), [1])
+        MultiProcessPipe(nn.Sequential(), [1])

     with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
         # module and sum of balance have different length (module: 2, sum of balance: 1)
-        Pipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1])
+        MultiProcessPipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1])


 @torch_spawn([2])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def lazy_construction(pipeline_style):
     init_count = 0
@@ -796,7 +798,7 @@ def lazy_construction(pipeline_style):
         LazyModule(lambda: Custom()),
     ]

-    pipe = Pipe(model, balance=[2, 2], style=pipeline_style, worker_map=get_worker_map())
+    pipe = MultiProcessPipe(model, balance=[2, 2], style=pipeline_style, worker_map=get_worker_map())

     assert isinstance(pipe[0], Custom)
     assert isinstance(pipe[1], Custom)
@@ -806,17 +808,17 @@ def lazy_construction(pipeline_style):

 @torch_spawn([2])
 @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="doesn't apply to mpi")
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def missing_worker_map(pipeline_style):
     model = nn.Sequential(nn.ReLU(), nn.ReLU())

     with pytest.raises(ValueError, match="'RpcTransport' requires 'worker_map' to be set"):
-        Pipe(model, [1, 1], style=pipeline_style)
+        MultiProcessPipe(model, [1, 1], style=pipeline_style)


 @torch_spawn([2])
 @pytest.mark.skip(reason="currently broken")
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def verify_module_duplicate_parameters_on_distinct_partitions(pipeline_style):
     class Surrogate(nn.Module):
         def __init__(self, module):
@@ -828,23 +830,23 @@ def verify_module_duplicate_parameters_on_distinct_partitions(pipeline_style):
     # FIXME(tom) can't have duplicate params with separate processes
     with pytest.raises(ValueError, match="module with duplicate parameters on distinct devices is not supported"):
-        Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
+        MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())


 @torch_spawn([4])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def pipelined_backward(pipeline_style):
     model = nn.Sequential(nn.ReLU(), nn.ReLU())

     destroy_model_parallel()
     initialize_model_parallel(1, 4)
-    pipe = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
+    pipe = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())

     assert pipe.pipelined_backward is False

     destroy_model_parallel()
     initialize_model_parallel(2, 2)
-    pipe = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
+    pipe = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())

     assert pipe.pipelined_backward is True
@@ -853,7 +855,9 @@ def pipelined_backward(pipeline_style):
 def async_event_loop():
     model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU())
-    pipe = Pipe(model, [1, 1, 1, 1], style=Pipe.AsyncSchedule, worker_map=get_worker_map(), chunks=10)
+    pipe = MultiProcessPipe(
+        model, [1, 1, 1, 1], style=MultiProcessPipe.AsyncSchedule, worker_map=get_worker_map(), chunks=10
+    )

     inputs = torch.rand(100, 10)
@@ -869,7 +873,7 @@ def reuse_lazy():
     reused = LazyModule(lambda: nn.Linear(10, 10))
     model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
     # model = [reused, reused, nn.Linear(10, 10), nn.ReLU(), reused, reused, nn.ReLU(), reused, reused, nn.ReLU()]
-    pipe = Pipe(model, [3, 1, 1], style=Pipe.AsyncSchedule, worker_map=get_worker_map())
+    pipe = MultiProcessPipe(model, [3, 1, 1], style=MultiProcessPipe.AsyncSchedule, worker_map=get_worker_map())
     pipe.eval()
     output = pipe(torch.rand(10))
@@ -887,7 +891,7 @@ def reuse_lazy():
     # ensure identical weights but no sharing between model and pipe
     reused = nn.Linear(10, 10)
     layers = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
-    pipe = Pipe(layers, [3, 1, 1], style=Pipe.AsyncSchedule, worker_map=get_worker_map())
+    pipe = MultiProcessPipe(layers, [3, 1, 1], style=MultiProcessPipe.AsyncSchedule, worker_map=get_worker_map())
     pipe.eval()
     model_optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
     pipe_optimizer = torch.optim.SGD(pipe.parameters(), lr=0.01, momentum=0.9) if len(list(pipe.parameters())) else None
@@ -931,7 +935,7 @@ def reuse_lazy():
 def test_instantiate_partition():
     from fairscale.nn.pipe.async_schedule import Location
-    from fairscale.nn.pipe.pipe import instantiate_partition
+    from fairscale.nn.pipe.multiprocess_pipe import instantiate_partition

     class FakeGroup:
         def __init__(self, rank, size):
@@ -947,7 +951,7 @@ def test_instantiate_partition():
     def check_partitions(model, balance, expected_order, expected_ranks):
         """Check that the instantiated model matches expectations of order and rank.
         model: a list of modules or an nn.Sequential
-        balance: the balance argument to Pipe
+        balance: the balance argument to MultiProcessPipe
         expected_order: the index of modules in `model` in the order they will
                         be executed, grouped by nn.Sequential
         expected_ranks: the rank that each module will be executed on
@@ -959,7 +963,9 @@ def test_instantiate_partition():
         # Collect `Invocation` and `Invocation` -> `ModuleWrapper` mapping from
         # instantiated model
         for rank in range(len(balance)):
-            instantiated = instantiate_partition(model, balance, FakeGroup(rank, len(balance)), Pipe.AsyncSchedule)
+            instantiated = instantiate_partition(
+                model, balance, FakeGroup(rank, len(balance)), MultiProcessPipe.AsyncSchedule
+            )
             for part in instantiated:
                 assert isinstance(part.module, nn.Sequential)
                 for inv in part.invocations:
...
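
Aside: the check_partitions helper in the hunks above takes the raw module list, the balance, and the expected execution order and ranks. A hypothetical call, with argument values invented for illustration and the grouping format inferred from the docstring, might look like:

    model = [nn.Linear(10, 10), nn.ReLU()]
    # one module per rank: module 0 runs first on rank 0, module 1 second on rank 1
    check_partitions(model, balance=[1, 1], expected_order=[[0], [1]], expected_ranks=[0, 1])
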
@@ -21,13 +21,13 @@ import pytest
 import torch
 from torch import nn

-from fairscale.nn import Pipe
+from fairscale.nn.pipe import MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn


 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def simple_linears(pipeline_style):
     def sum_grad(parameters):
         return sum([p.grad.sum() for p in parameters if p.grad is not None])
@@ -40,7 +40,7 @@ def simple_linears(pipeline_style):
     inputs = torch.rand(8, 1)
     model = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 4), nn.Linear(4, 2), nn.Linear(2, 1),)

-    # Without Pipe
+    # Without MultiProcessPipe
     outputs = model(inputs)
     loss = outputs.mean()
     loss.backward()
@@ -54,20 +54,20 @@ def simple_linears(pipeline_style):
     zero_grad(model.parameters())

-    # With Pipe
-    model = Pipe(model, [2, 2], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
+    # With MultiProcessPipe
+    model = MultiProcessPipe(model, [2, 2], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
     outputs = model(inputs)
     if model.group.rank() == 1:
         loss = outputs.mean()
         loss.backward()
-        grad_with_pipe = sum_grad(model.pipeline.mp_partitions[0].module.parameters())
+        grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters())

         # Both grads should be identical.
         assert torch.allclose(grad_with_pipe, grad_without_pipe[1])
     else:
         model.back_helper(outputs)
-        grad_with_pipe = sum_grad(model.pipeline.mp_partitions[0].module.parameters())
+        grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters())

         # Both grads should be identical.
         assert torch.allclose(grad_with_pipe, grad_without_pipe[0])
...
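
All of the updated tests share the same spawn-and-parametrize shape, so a new test for the renamed class would follow the pattern below (a sketch only; the test name is made up, and torch_spawn/get_worker_map come from fairscale.utils.testing exactly as used throughout this diff):

    import pytest
    import torch
    from torch import nn

    from fairscale.nn.pipe import MultiProcessPipe
    from fairscale.utils.testing import get_worker_map, torch_spawn


    @torch_spawn([2])  # spawn two ranks and run the body on each
    @pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
    def two_stage_forward(pipeline_style):
        # one nn.Linear per rank; worker_map tells the RPC transport where each rank lives
        model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
        pipe = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
        pipe(torch.rand(4, 1))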