Unverified Commit 39675773 authored by msbaines's avatar msbaines Committed by GitHub
Browse files

[refactor] multiprocess_pipe: cleanup __init__ (#357)

parent de713d1e
......@@ -220,9 +220,6 @@ class MultiProcessPipe(Module):
) -> None:
super().__init__()
chunks = int(chunks)
checkpoint = str(checkpoint)
if chunks <= 0:
raise ValueError("number of chunks must be positive integer")
if checkpoint not in ["always", "except_last", "never"]:
......@@ -259,10 +256,18 @@ class MultiProcessPipe(Module):
f" {len(self.balance)})"
)
if isinstance(module, nn.Sequential):
local_partitions = split_module(module, self.balance)
self._skip_layout = inspect_skip_layout(local_partitions)
else:
self._skip_layout = SkipLayout(len(module), {}) # FIXME(tom)
rank = self.group.rank()
self.final_stage = rank == len(self.balance) - 1
if rank >= len(self.balance):
warnings.warn("More ranks than partitions, some ranks unused")
self.partitions: List[ModuleWrapper] = []
self.pipeline = None
else:
self.partitions = self.instantiate_partition(module, self.balance, self.group)
if deferred_batch_norm:
......@@ -270,21 +275,10 @@ class MultiProcessPipe(Module):
part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks)
for name, part in enumerate(self.partitions):
self.add_module(str(name), part.module)
if isinstance(module, nn.Sequential):
local_partitions = split_module(module, self.balance)
self._skip_layout = inspect_skip_layout(local_partitions)
else:
self._skip_layout = SkipLayout(len(module), {}) # FIXME(tom)
self.create_pipeline()
rank = self.group.rank()
if rank >= len(self.balance):
self.pipeline = None
self.final_stage = False
else:
self.final_stage = rank == len(self.balance) - 1
del module
self.create_pipeline()
del module
if self.pipelined_backward is None:
if get_model_parallel_world_size() > 1:
self.pipelined_backward = True
......
......@@ -109,16 +109,9 @@ def mpi():
@torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def public_attrs(pipe_class):
class MyString:
def __init__(self, value):
self.value = value
def __str__(self):
return self.value
model = nn.Sequential(nn.Linear(1, 1))
pipe = pipe_class(model, balance=(1,), worker_map=get_worker_map(), chunks=42.000, checkpoint=MyString("always"),)
pipe = pipe_class(model, balance=(1,), worker_map=get_worker_map(), chunks=42, checkpoint="always",)
assert pipe.balance == [1]
assert pipe.chunks == 42
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment