Unverified Commit cae9b638 authored by msbaines, committed by GitHub

[refactor] pipe: separate out Single and MultiProcess pipe (#326)

parent eab1551a
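
For callers, the refactor is mostly a mechanical rename: the multi-process variant of Pipe is now exposed as MultiProcessPipe, with the same constructor arguments these tests exercise (balance, style, worker_map, chunks). A minimal before/after sketch, assuming torch.distributed is already initialized and using a hypothetical worker_map in place of the tests' get_worker_map() helper:

import torch.nn as nn
from fairscale.nn.pipe import MultiProcessPipe

# Hypothetical rank -> worker-name mapping; the tests build theirs with
# fairscale.utils.testing.get_worker_map().
worker_map = {0: "worker0", 1: "worker1"}

model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))

# Before this commit: Pipe(model, balance=[1, 1], style=Pipe.MultiProcess, ...)
# After: the same call, with Pipe replaced by MultiProcessPipe.
pipe = MultiProcessPipe(
    model,
    balance=[1, 1],                       # one layer per pipeline rank
    style=MultiProcessPipe.MultiProcess,  # or MultiProcessPipe.AsyncSchedule
    worker_map=worker_map,
    chunks=4,                             # number of micro-batches
)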
@@ -31,15 +31,15 @@ from fairscale.nn.model_parallel.initialize import (
get_pipeline_parallel_group,
initialize_model_parallel,
)
from fairscale.nn.pipe import LazyModule, Pipe
from fairscale.nn.pipe import LazyModule, MultiProcessPipe
from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn, torch_version
@torch_spawn([2])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def parameters(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1))
pipe = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1)
pipe = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1)
if torch.distributed.get_rank() == 0:
assert list(pipe.parameters()) != []
else:
@@ -107,7 +107,7 @@ def mpi():
@torch_spawn([1])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def public_attrs(pipeline_style):
class MyString:
def __init__(self, value):
@@ -118,7 +118,7 @@ def public_attrs(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1))
pipe = Pipe(
pipe = MultiProcessPipe(
model,
balance=(1,),
style=pipeline_style,
@@ -127,9 +127,7 @@ def public_attrs(pipeline_style):
checkpoint=MyString("always"),
)
print(f"balance = {pipe.devices}")
assert pipe.balance == [1]
assert pipe.devices is None
assert pipe.chunks == 42
assert isinstance(pipe.chunks, int)
assert pipe.checkpoint == "always"
@@ -138,13 +136,13 @@ def public_attrs(pipeline_style):
@torch_spawn([2])
@pytest.mark.parametrize("balance", [[2], [1, 1]])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def sequential_like(balance, pipeline_style):
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
model = nn.Sequential(a, b)
model = Pipe(model, balance, style=pipeline_style, worker_map=get_worker_map())
model = MultiProcessPipe(model, balance, style=pipeline_style, worker_map=get_worker_map())
if balance == [2]:
if torch.distributed.get_rank() == 0:
@@ -177,7 +175,7 @@ def sequential_like(balance, pipeline_style):
@torch_spawn([1])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def balance_wrong_length(pipeline_style):
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
@@ -185,14 +183,14 @@ def balance_wrong_length(pipeline_style):
model = nn.Sequential(a, b)
with pytest.raises(ValueError):
Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
with pytest.raises(ValueError):
Pipe(model, balance=[3], style=pipeline_style, worker_map=get_worker_map())
MultiProcessPipe(model, balance=[3], style=pipeline_style, worker_map=get_worker_map())
@torch_spawn([2])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def balance_less_than_1(pipeline_style):
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
@@ -200,39 +198,39 @@ def balance_less_than_1(pipeline_style):
model = nn.Sequential(a, b)
with pytest.raises(ValueError):
Pipe(model, balance=[0, 2], style=pipeline_style, worker_map=get_worker_map())
MultiProcessPipe(model, balance=[0, 2], style=pipeline_style, worker_map=get_worker_map())
with pytest.raises(ValueError):
Pipe(model, balance=[-1, 3], style=pipeline_style, worker_map=get_worker_map())
MultiProcessPipe(model, balance=[-1, 3], style=pipeline_style, worker_map=get_worker_map())
@torch_spawn([1])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def chunks_less_than_1(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1))
with pytest.raises(ValueError):
Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=0)
MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=0)
with pytest.raises(ValueError):
Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=-1)
MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=-1)
@torch_spawn([1])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def too_few_devices(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1))
with pytest.raises(IndexError):
# len(balance) > group.size()
model = Pipe(model, balance=[1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map())
model = MultiProcessPipe(model, balance=[1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map())
@torch_spawn([1])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def batch_size_indivisible(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
with pytest.warns(None) as record:
model(torch.rand(7, 1))
@@ -242,10 +240,10 @@ def batch_size_indivisible(pipeline_style):
@torch_spawn([1])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def batch_size_small(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
with pytest.warns(None) as record:
model(torch.rand(2, 1))
@@ -255,7 +253,7 @@ def batch_size_small(pipeline_style):
@torch_spawn([1])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def checkpoint_mode(pipeline_style):
def count_grad_fn(grad_fn, name, visited=set()):
if grad_fn in visited:
@@ -275,7 +273,7 @@ def checkpoint_mode(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1))
input = torch.rand(2, 1)
always = Pipe(
always = MultiProcessPipe(
model,
balance=[1],
style=pipeline_style,
@@ -284,7 +282,7 @@ def checkpoint_mode(pipeline_style):
checkpoint="always",
pipelined_backward=False,
)
except_last = Pipe(
except_last = MultiProcessPipe(
model,
balance=[1],
style=pipeline_style,
@@ -293,7 +291,7 @@ def checkpoint_mode(pipeline_style):
checkpoint="except_last",
pipelined_backward=False,
)
never = Pipe(
never = MultiProcessPipe(
model,
balance=[1],
style=pipeline_style,
@@ -313,12 +311,12 @@ def checkpoint_mode(pipeline_style):
@torch_spawn([1])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def checkpoint_mode_invalid(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1))
with pytest.raises(ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'"):
Pipe(
MultiProcessPipe(
model,
balance=[1],
style=pipeline_style,
@@ -329,23 +327,27 @@ def checkpoint_mode_invalid(pipeline_style):
@torch_spawn([1])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def checkpoint_mode_when_chunks_1(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1))
# All checkpoint modes are fine.
Pipe(
MultiProcessPipe(
model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="except_last",
)
Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="always")
Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="never")
MultiProcessPipe(
model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="always"
)
MultiProcessPipe(
model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="never"
)
@torch_spawn([1])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def checkpoint_eval(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(
model = MultiProcessPipe(
model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
)
input = torch.rand(2, 1)
@@ -373,7 +375,7 @@ def checkpoint_eval(pipeline_style):
@torch_spawn([2])
@pytest.mark.xfail(torch_version() < (1, 6, 0), reason="Doesn't work on torch < 1.6.0", strict=True)
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def checkpoint_non_float_input(pipeline_style):
class ForkNonFloat(nn.Module):
def forward(self, input):
@@ -384,7 +386,7 @@ def checkpoint_non_float_input(pipeline_style):
return input[0] * 2
model = nn.Sequential(ForkNonFloat(), JoinNonFloat())
model = Pipe(
model = MultiProcessPipe(
model,
balance=[1, 1],
style=pipeline_style,
@@ -399,17 +401,17 @@ def checkpoint_non_float_input(pipeline_style):
if model.group.rank() == 1:
# with torch.autograd.detect_anomaly():
output.backward()
elif pipeline_style == Pipe.MultiProcess:
elif pipeline_style == MultiProcessPipe.MultiProcess:
model.back_helper(output)
torch.distributed.barrier()
@torch_spawn([1])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def no_grad(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2)
model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2)
input = torch.rand(2, 1)
latent = None
@@ -421,7 +423,7 @@ def no_grad(pipeline_style):
nonlocal latent
latent = output
partition = model.mp_partitions[0]
partition = model.partitions[0]
partition.module.register_forward_hook(hook)
with torch.no_grad():
@@ -431,7 +433,7 @@ def no_grad(pipeline_style):
@torch_spawn([1])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def exception(pipeline_style):
class ExpectedException(Exception):
pass
@@ -441,7 +443,7 @@ def exception(pipeline_style):
raise ExpectedException()
model = nn.Sequential(Raise())
model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1)
model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1)
with pytest.raises(ExpectedException):
model(torch.rand(1))
@@ -451,7 +453,7 @@ def exception(pipeline_style):
@torch_spawn([4])
@pytest.mark.skipif(torch.cuda.is_available() and torch.cuda.device_count() < 4, reason="Not enough GPUs")
@pytest.mark.xfail(strict=True)
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def exception_early_stop_asap(pipeline_style):
"""Even the first partitions have finished to process, the partition before
the failed partition hould be killed as soon as possible.
@@ -480,7 +482,7 @@ def exception_early_stop_asap(pipeline_style):
raise ExpectedException()
model = nn.Sequential(Pass(), Pass(), Counter(), Raise())
model = Pipe(model, [1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
model = MultiProcessPipe(model, [1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
with pytest.raises(ExpectedException):
model(torch.rand(3))
@@ -490,7 +492,7 @@ def exception_early_stop_asap(pipeline_style):
@torch_spawn([1])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def input_pair(pipeline_style):
class Two(nn.Module):
def __init__(self):
@@ -503,7 +505,7 @@ def input_pair(pipeline_style):
return (self.fc_a(a), self.fc_b(b))
model = nn.Sequential(Two())
model = Pipe(
model = MultiProcessPipe(
model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
)
@@ -519,7 +521,7 @@ def input_pair(pipeline_style):
@torch_spawn([1])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def input_singleton(pipeline_style):
class One(nn.Module):
def __init__(self):
@@ -531,7 +533,7 @@ def input_singleton(pipeline_style):
return (self.fc(a),)
model = nn.Sequential(One())
model = Pipe(
model = MultiProcessPipe(
model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
)
@@ -546,10 +548,10 @@ def input_singleton(pipeline_style):
@torch_spawn([1])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def input_varargs(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
a = torch.rand(1)
b = torch.rand(1)
@@ -560,14 +562,14 @@ def input_varargs(pipeline_style):
@torch_spawn([1])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def non_tensor(pipeline_style):
class NonTensor(nn.Module):
def forward(self, _):
return "hello"
model = nn.Sequential(NonTensor())
model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
x = torch.rand(1)
# TypeError: expected Tensor as element 0 in argument 0, but got str
@@ -580,14 +582,14 @@ def non_tensor(pipeline_style):
@torch_spawn([1])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def non_tensor_tuple(pipeline_style):
class NonTensorTuple(nn.Module):
def forward(self, x):
return (x, "hello")
model = nn.Sequential(NonTensorTuple())
model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
x = torch.rand(1)
# TypeError: CheckpointBackward.forward: expected Variable (got str) for return value 1
@@ -602,7 +604,7 @@ def non_tensor_tuple(pipeline_style):
@torch_spawn([1])
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
@pytest.mark.parametrize("lazy", [True, False])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def deferred_batch_norm(checkpoint, lazy, pipeline_style):
bn = nn.BatchNorm2d(3)
pipe_bn = deepcopy(bn)
@@ -611,7 +613,7 @@ def deferred_batch_norm(checkpoint, lazy, pipeline_style):
model = [LazyModule(pipe_fn)]
else:
model = nn.Sequential(pipe_bn)
pipe = Pipe(
pipe = MultiProcessPipe(
model,
balance=[1],
style=pipeline_style,
@@ -632,7 +634,7 @@ def deferred_batch_norm(checkpoint, lazy, pipeline_style):
@torch_spawn([1])
@pytest.mark.parametrize("checkpoint", ["never", "always"])
@pytest.mark.parametrize("lazy", [True, False])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def deferred_batch_norm_params(checkpoint, lazy, pipeline_style):
bn = nn.BatchNorm2d(3)
pipe_bn = deepcopy(bn)
@@ -641,7 +643,7 @@ def deferred_batch_norm_params(checkpoint, lazy, pipeline_style):
model = [LazyModule(pipe_fn)]
else:
model = nn.Sequential(pipe_bn)
pipe = Pipe(
pipe = MultiProcessPipe(
model,
balance=[1],
style=pipeline_style,
@@ -663,7 +665,7 @@ def deferred_batch_norm_params(checkpoint, lazy, pipeline_style):
@torch_spawn([4])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def devices(pipeline_style):
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
@@ -671,7 +673,7 @@ def devices(pipeline_style):
# There are two extra ranks.
model = nn.Sequential(a, b, c)
model = Pipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map())
model = MultiProcessPipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map())
# Extra devices must be discarded.
if model.group.rank() == 3:
@@ -679,17 +681,17 @@ def devices(pipeline_style):
@torch_spawn([2])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def partitions(pipeline_style):
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
model = nn.Sequential(a, b)
model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
assert isinstance(model.mp_partitions, list)
assert isinstance(model.partitions, list)
assert len(model) == 1
assert isinstance(model.mp_partitions[0].module, nn.Sequential)
assert isinstance(model.partitions[0].module, nn.Sequential)
if model.group.rank() == 0:
assert "0.0.weight" in model.state_dict()
@@ -699,13 +701,13 @@ def partitions(pipeline_style):
@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def deny_moving(pipeline_style):
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
model = nn.Sequential(a, b)
model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
model.cuda()
model.cpu()
@@ -723,29 +725,29 @@ def deny_moving(pipeline_style):
@torch_spawn([1])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def empty_module(pipeline_style):
# Empty sequential module is not illegal.
model = nn.Sequential()
model = Pipe(model, [], style=pipeline_style, worker_map=get_worker_map())
model = MultiProcessPipe(model, [], style=pipeline_style, worker_map=get_worker_map())
assert model(torch.tensor([42])) == torch.tensor([42])
assert model((torch.tensor([42]),)) == (torch.tensor([42]),)
# But only tensor or tensors is legal in Pipe.
# But only a tensor or a tuple of tensors is legal in MultiProcessPipe.
with pytest.raises(TypeError):
model(42)
@torch_spawn([2])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def named_children(pipeline_style):
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
model = nn.Sequential(OrderedDict([("a", a), ("b", b)]))
model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
names = set(n for n, _ in model.named_modules())
if model.group.rank() == 0:
@@ -753,30 +755,30 @@ def named_children(pipeline_style):
else:
assert "0.b" in names
# Pipe doesn't support __getattr__. Unlike nn.Sequential, Pipe requires
# MultiProcessPipe doesn't support __getattr__. Unlike nn.Sequential, MultiProcessPipe requires
# several methods in its namespace.
with pytest.raises(AttributeError):
model.a
@torch_spawn([1])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def recommend_auto_balance(pipeline_style):
with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
# balance is required
Pipe(nn.Sequential())
MultiProcessPipe(nn.Sequential())
with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
# module and sum of balance have different length (module: 0, sum of balance: 1)
Pipe(nn.Sequential(), [1])
MultiProcessPipe(nn.Sequential(), [1])
with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
# module and sum of balance have different length (module: 2, sum of balance: 1)
Pipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1])
MultiProcessPipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1])
@torch_spawn([2])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def lazy_construction(pipeline_style):
init_count = 0
@@ -796,7 +798,7 @@ def lazy_construction(pipeline_style):
LazyModule(lambda: Custom()),
]
pipe = Pipe(model, balance=[2, 2], style=pipeline_style, worker_map=get_worker_map())
pipe = MultiProcessPipe(model, balance=[2, 2], style=pipeline_style, worker_map=get_worker_map())
assert isinstance(pipe[0], Custom)
assert isinstance(pipe[1], Custom)
@@ -806,17 +808,17 @@ def lazy_construction(pipeline_style):
@torch_spawn([2])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="doesn't apply to mpi")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def missing_worker_map(pipeline_style):
model = nn.Sequential(nn.ReLU(), nn.ReLU())
with pytest.raises(ValueError, match="'RpcTransport' requires 'worker_map' to be set"):
Pipe(model, [1, 1], style=pipeline_style)
MultiProcessPipe(model, [1, 1], style=pipeline_style)
@torch_spawn([2])
@pytest.mark.skip(reason="currently broken")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def verify_module_duplicate_parameters_on_distinct_partitions(pipeline_style):
class Surrogate(nn.Module):
def __init__(self, module):
@@ -828,23 +830,23 @@ def verify_module_duplicate_parameters_on_distinct_partitions(pipeline_style):
# FIXME(tom) can't have duplicate params with separate processes
with pytest.raises(ValueError, match="module with duplicate parameters on distinct devices is not supported"):
Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
@torch_spawn([4])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def pipelined_backward(pipeline_style):
model = nn.Sequential(nn.ReLU(), nn.ReLU())
destroy_model_parallel()
initialize_model_parallel(1, 4)
pipe = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
pipe = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
assert pipe.pipelined_backward is False
destroy_model_parallel()
initialize_model_parallel(2, 2)
pipe = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
pipe = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
assert pipe.pipelined_backward is True
@@ -853,7 +855,9 @@ def pipelined_backward(pipeline_style):
def async_event_loop():
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU())
pipe = Pipe(model, [1, 1, 1, 1], style=Pipe.AsyncSchedule, worker_map=get_worker_map(), chunks=10)
pipe = MultiProcessPipe(
model, [1, 1, 1, 1], style=MultiProcessPipe.AsyncSchedule, worker_map=get_worker_map(), chunks=10
)
inputs = torch.rand(100, 10)
@@ -869,7 +873,7 @@ def reuse_lazy():
reused = LazyModule(lambda: nn.Linear(10, 10))
model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
# model = [reused, reused, nn.Linear(10, 10), nn.ReLU(), reused, reused, nn.ReLU(), reused, reused, nn.ReLU()]
pipe = Pipe(model, [3, 1, 1], style=Pipe.AsyncSchedule, worker_map=get_worker_map())
pipe = MultiProcessPipe(model, [3, 1, 1], style=MultiProcessPipe.AsyncSchedule, worker_map=get_worker_map())
pipe.eval()
output = pipe(torch.rand(10))
@@ -887,7 +891,7 @@ def reuse_lazy():
# ensure identical weights but no sharing between model and pipe
reused = nn.Linear(10, 10)
layers = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
pipe = Pipe(layers, [3, 1, 1], style=Pipe.AsyncSchedule, worker_map=get_worker_map())
pipe = MultiProcessPipe(layers, [3, 1, 1], style=MultiProcessPipe.AsyncSchedule, worker_map=get_worker_map())
pipe.eval()
model_optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
pipe_optimizer = torch.optim.SGD(pipe.parameters(), lr=0.01, momentum=0.9) if len(list(pipe.parameters())) else None
@@ -931,7 +935,7 @@ def reuse_lazy():
def test_instantiate_partition():
from fairscale.nn.pipe.async_schedule import Location
from fairscale.nn.pipe.pipe import instantiate_partition
from fairscale.nn.pipe.multiprocess_pipe import instantiate_partition
class FakeGroup:
def __init__(self, rank, size):
@@ -947,7 +951,7 @@ def test_instantiate_partition():
def check_partitions(model, balance, expected_order, expected_ranks):
"""Check the instantiated model matches expectation of order and rank
model: a list of modules or an nn.Sequential
balance: the balance argument to Pipe
balance: the balance argument to MultiProcessPipe
expected_order: the index of modules in `model` in the order they will
be executed, grouped by nn.Sequential
expected_ranks: the rank that each module will be executed on
@@ -959,7 +963,9 @@ def test_instantiate_partition():
# Collect `Invocation` and `Invocation` -> `ModuleWrapper` mapping from
# instantiated model
for rank in range(len(balance)):
instantiated = instantiate_partition(model, balance, FakeGroup(rank, len(balance)), Pipe.AsyncSchedule)
instantiated = instantiate_partition(
model, balance, FakeGroup(rank, len(balance)), MultiProcessPipe.AsyncSchedule
)
for part in instantiated:
assert isinstance(part.module, nn.Sequential)
for inv in part.invocations:
......
@@ -21,13 +21,13 @@ import pytest
import torch
from torch import nn
from fairscale.nn import Pipe
from fairscale.nn.pipe import MultiProcessPipe
from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn
@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def simple_linears(pipeline_style):
def sum_grad(parameters):
return sum([p.grad.sum() for p in parameters if p.grad is not None])
@@ -40,7 +40,7 @@ def simple_linears(pipeline_style):
inputs = torch.rand(8, 1)
model = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 4), nn.Linear(4, 2), nn.Linear(2, 1),)
# Without Pipe
# Without MultiProcessPipe
outputs = model(inputs)
loss = outputs.mean()
loss.backward()
@@ -54,20 +54,20 @@ def simple_linears(pipeline_style):
zero_grad(model.parameters())
# With Pipe
model = Pipe(model, [2, 2], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
# With MultiProcessPipe
model = MultiProcessPipe(model, [2, 2], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
outputs = model(inputs)
if model.group.rank() == 1:
loss = outputs.mean()
loss.backward()
grad_with_pipe = sum_grad(model.pipeline.mp_partitions[0].module.parameters())
grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters())
# Both grads should be identical.
assert torch.allclose(grad_with_pipe, grad_without_pipe[1])
else:
model.back_helper(outputs)
grad_with_pipe = sum_grad(model.pipeline.mp_partitions[0].module.parameters())
grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters())
# Both grads should be identical.
assert torch.allclose(grad_with_pipe, grad_without_pipe[0])
......
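
The tail of simple_linears above also illustrates the calling convention the multi-process tests rely on: the forward output is only meaningful on the last pipeline rank, so that rank computes the loss and calls backward(), while earlier ranks drive their part of the backward pass through back_helper(). A condensed sketch of that per-rank pattern, assuming pipe is a MultiProcessPipe built as in the test:

outputs = pipe(inputs)

if pipe.group.rank() == pipe.group.size() - 1:
    # Last pipeline stage: the real output lives here.
    loss = outputs.mean()
    loss.backward()
else:
    # Earlier stages: kick off their share of the backward pass.
    pipe.back_helper(outputs)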