Unverified Commit 39675773 authored by msbaines's avatar msbaines Committed by GitHub
Browse files

[refactor] multiprocess_pipe: cleanup __init__ (#357)

parent de713d1e
...@@ -220,9 +220,6 @@ class MultiProcessPipe(Module): ...@@ -220,9 +220,6 @@ class MultiProcessPipe(Module):
) -> None: ) -> None:
super().__init__() super().__init__()
chunks = int(chunks)
checkpoint = str(checkpoint)
if chunks <= 0: if chunks <= 0:
raise ValueError("number of chunks must be positive integer") raise ValueError("number of chunks must be positive integer")
if checkpoint not in ["always", "except_last", "never"]: if checkpoint not in ["always", "except_last", "never"]:
...@@ -259,10 +256,18 @@ class MultiProcessPipe(Module): ...@@ -259,10 +256,18 @@ class MultiProcessPipe(Module):
f" {len(self.balance)})" f" {len(self.balance)})"
) )
if isinstance(module, nn.Sequential):
local_partitions = split_module(module, self.balance)
self._skip_layout = inspect_skip_layout(local_partitions)
else:
self._skip_layout = SkipLayout(len(module), {}) # FIXME(tom)
rank = self.group.rank() rank = self.group.rank()
self.final_stage = rank == len(self.balance) - 1
if rank >= len(self.balance): if rank >= len(self.balance):
warnings.warn("More ranks than partitions, some ranks unused") warnings.warn("More ranks than partitions, some ranks unused")
self.partitions: List[ModuleWrapper] = [] self.partitions: List[ModuleWrapper] = []
self.pipeline = None
else: else:
self.partitions = self.instantiate_partition(module, self.balance, self.group) self.partitions = self.instantiate_partition(module, self.balance, self.group)
if deferred_batch_norm: if deferred_batch_norm:
...@@ -270,21 +275,10 @@ class MultiProcessPipe(Module): ...@@ -270,21 +275,10 @@ class MultiProcessPipe(Module):
part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks) part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks)
for name, part in enumerate(self.partitions): for name, part in enumerate(self.partitions):
self.add_module(str(name), part.module) self.add_module(str(name), part.module)
if isinstance(module, nn.Sequential): self.create_pipeline()
local_partitions = split_module(module, self.balance)
self._skip_layout = inspect_skip_layout(local_partitions)
else:
self._skip_layout = SkipLayout(len(module), {}) # FIXME(tom)
rank = self.group.rank() del module
if rank >= len(self.balance):
self.pipeline = None
self.final_stage = False
else:
self.final_stage = rank == len(self.balance) - 1
self.create_pipeline()
del module
if self.pipelined_backward is None: if self.pipelined_backward is None:
if get_model_parallel_world_size() > 1: if get_model_parallel_world_size() > 1:
self.pipelined_backward = True self.pipelined_backward = True
......
...@@ -109,16 +109,9 @@ def mpi(): ...@@ -109,16 +109,9 @@ def mpi():
@torch_spawn([1]) @torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe]) @pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def public_attrs(pipe_class): def public_attrs(pipe_class):
class MyString:
def __init__(self, value):
self.value = value
def __str__(self):
return self.value
model = nn.Sequential(nn.Linear(1, 1)) model = nn.Sequential(nn.Linear(1, 1))
pipe = pipe_class(model, balance=(1,), worker_map=get_worker_map(), chunks=42.000, checkpoint=MyString("always"),) pipe = pipe_class(model, balance=(1,), worker_map=get_worker_map(), chunks=42, checkpoint="always",)
assert pipe.balance == [1] assert pipe.balance == [1]
assert pipe.chunks == 42 assert pipe.chunks == 42
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment