Unverified Commit d624b81a authored by msbaines, committed by GitHub

[refactor] multiprocess_pipe: remove retain_graph __init__ param (#358)

It is not currently being used, so we can simplify the interface
by removing it.
parent 39675773
@@ -163,9 +163,6 @@ class MultiProcessPipe(Module):
             at the same time. Defaults to `True` if
             `get_model_parallel_world_size() > 1`
             (default: `None`)
-        retain_graph (bool):
-            The value passed to `torch.autograd.backwards(..., retain_graph=<value>)
-            (default: = `True`)
 
     Raises:
         TypeError:
@@ -216,7 +213,6 @@ class MultiProcessPipe(Module):
         checkpoint: str = checkpoint,
         deferred_batch_norm: bool = False,
         pipelined_backward: bool = None,
-        retain_graph: bool = False,
     ) -> None:
         super().__init__()
@@ -237,7 +233,6 @@ class MultiProcessPipe(Module):
         self.chunks = chunks
         self.checkpoint = checkpoint
         self.pipelined_backward = pipelined_backward
-        self.retain_graph = retain_graph
 
         self.pipeline: Optional[MultiProcessPipeline]
         self.lock = threading.Lock()
@@ -417,7 +412,7 @@ class MultiProcessPipe(Module):
                     torch.device(torch.cuda.current_device() if torch.cuda.is_available() else "cpu"),
                     requires_grad=True,
                 )
-                output = PipelinedBackwardPass.apply(output, batches, phony, True)  # self.retain_graph)
+                output = PipelinedBackwardPass.apply(output, batches, phony)
            else:
                 output = microbatch.gather(batches)
         else:
@@ -439,9 +434,8 @@ class PipelinedBackwardPass(torch.autograd.Function):
 class PipelinedBackwardPass(torch.autograd.Function):
     @staticmethod
     # type: ignore
-    def forward(ctx, input: TensorOrTensors, batches, phony, retain_graph) -> TensorOrTensors:
+    def forward(ctx, input: TensorOrTensors, batches, phony) -> TensorOrTensors:
         ctx.batches = batches
-        ctx.retain_graph = retain_graph
         return input
 
     @staticmethod
@@ -452,7 +446,7 @@ class PipelinedBackwardPass(torch.autograd.Function):
         for grad, batch in reversed(list(zip(grad_batches, ctx.batches))):
             for t in batch:
                 t.retain_grad()
-            torch.autograd.backward(batch.tensor_or_tensors, grad_tensors=(*grad,), retain_graph=ctx.retain_graph)
+            torch.autograd.backward(batch.tensor_or_tensors, grad_tensors=(*grad,), retain_graph=True)
 
         with torch.no_grad():
             if ctx.batches[0].atomic:
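For context on the pattern this diff touches: PipelinedBackwardPass is a custom torch.autograd.Function that threads a zero-sized "phony" tensor with requires_grad=True through apply(), so that autograd invokes the function's backward(), which then runs torch.autograd.backward() over each microbatch in reverse order. The sketch below is a minimal, self-contained illustration of that idea, not fairscale's implementation; the module, tensor names, and the two-microbatch split are made up for the example. retain_graph=True mirrors the value the commit now hard-codes at the per-microbatch backward call.

import torch


class PipelinedBackward(torch.autograd.Function):
    """Illustrative sketch of the pipelined-backward trick (not fairscale's code)."""

    @staticmethod
    def forward(ctx, output, batches, phony):
        # `phony` is an empty tensor with requires_grad=True; its only purpose
        # is to make autograd call backward() on this Function.
        ctx.batches = batches  # keep the per-microbatch outputs for backward
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Split the incoming gradient the same way the output was gathered,
        # then backprop each microbatch separately, last microbatch first.
        grad_batches = grad_output.chunk(len(ctx.batches))
        for grad, batch in reversed(list(zip(grad_batches, ctx.batches))):
            torch.autograd.backward(batch, grad_tensors=grad, retain_graph=True)
        return grad_output, None, None  # grads for (output, batches, phony)


# Hypothetical usage: two microbatches through a tiny single-stage "pipeline".
layer = torch.nn.Linear(4, 4)
microbatches = [layer(torch.randn(2, 4)) for _ in range(2)]
with torch.no_grad():
    gathered = torch.cat(microbatches)      # gather without building a graph
phony = torch.empty(0, requires_grad=True)  # forces backward() to be called
output = PipelinedBackward.apply(gathered, microbatches, phony)
output.sum().backward()                     # populates layer.weight.grad

Because the explicit per-microbatch backward calls carry the gradients to the parameters, the gathered output itself can be produced under no_grad, which is what makes the phony-tensor hook necessary in the first place.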