Unverified commit d624b81a authored by msbaines, committed by GitHub

[refactor] multiprocess_pipe: remove retain_graph __init__ param (#358)

The `retain_graph` parameter is not currently being used, so we can simplify
the interface by removing it.
parent 39675773
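
For callers, the change is limited to the constructor signature: `retain_graph` is no longer accepted, and the pipelined backward pass always retains the graph internally. A minimal caller-side sketch follows; the import path and the `balance` argument are assumptions made for illustration and do not come from this diff.

import torch.nn as nn

# Assumed import path; adjust to where MultiProcessPipe lives in your fairscale version.
from fairscale.nn.pipe import MultiProcessPipe

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

# Keyword arguments mirror the __init__ signature visible in this diff;
# `balance` is an assumed positional argument used only for illustration.
pipe = MultiProcessPipe(
    model,
    balance=[2, 1],
    chunks=4,
    checkpoint="except_last",
    # retain_graph=False,  <-- no longer accepted after this commit; the
    # backward pass hardcodes retain_graph=True instead.
)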
@@ -163,9 +163,6 @@ class MultiProcessPipe(Module):
             at the same time. Defaults to `True` if
             `get_model_parallel_world_size() > 1`
             (default: `None`)
-        retain_graph (bool):
-            The value passed to `torch.autograd.backwards(..., retain_graph=<value>)
-            (default: = `True`)
 
     Raises:
         TypeError:
@@ -216,7 +213,6 @@ class MultiProcessPipe(Module):
         checkpoint: str = checkpoint,
         deferred_batch_norm: bool = False,
         pipelined_backward: bool = None,
-        retain_graph: bool = False,
     ) -> None:
         super().__init__()
@@ -237,7 +233,6 @@ class MultiProcessPipe(Module):
         self.chunks = chunks
         self.checkpoint = checkpoint
         self.pipelined_backward = pipelined_backward
-        self.retain_graph = retain_graph
         self.pipeline: Optional[MultiProcessPipeline]
         self.lock = threading.Lock()
@@ -417,7 +412,7 @@ class MultiProcessPipe(Module):
                     torch.device(torch.cuda.current_device() if torch.cuda.is_available() else "cpu"),
                     requires_grad=True,
                 )
-                output = PipelinedBackwardPass.apply(output, batches, phony, True)  # self.retain_graph)
+                output = PipelinedBackwardPass.apply(output, batches, phony)
             else:
                 output = microbatch.gather(batches)
         else:
@@ -439,9 +434,8 @@
 class PipelinedBackwardPass(torch.autograd.Function):
     @staticmethod
     # type: ignore
-    def forward(ctx, input: TensorOrTensors, batches, phony, retain_graph) -> TensorOrTensors:
+    def forward(ctx, input: TensorOrTensors, batches, phony) -> TensorOrTensors:
         ctx.batches = batches
-        ctx.retain_graph = retain_graph
         return input
 
     @staticmethod
@@ -452,7 +446,7 @@ class PipelinedBackwardPass(torch.autograd.Function):
         for grad, batch in reversed(list(zip(grad_batches, ctx.batches))):
             for t in batch:
                 t.retain_grad()
-            torch.autograd.backward(batch.tensor_or_tensors, grad_tensors=(*grad,), retain_graph=ctx.retain_graph)
+            torch.autograd.backward(batch.tensor_or_tensors, grad_tensors=(*grad,), retain_graph=True)
 
         with torch.no_grad():
             if ctx.batches[0].atomic:
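
The hunks above show that `PipelinedBackwardPass.backward` now hardcodes `retain_graph=True` when it replays backward through each stashed micro-batch. The self-contained sketch below illustrates that pass-through pattern with plain tensors instead of fairscale's micro-batch objects; the class and variable names are illustrative and not taken from the source, and the per-chunk gradient is simplified to ones.

import torch
from torch import Tensor
from typing import Tuple


class PassThroughBackward(torch.autograd.Function):
    """Pass-through function that stashes per-chunk outputs in forward and
    replays backward through them one at a time (illustrative sketch)."""

    @staticmethod
    def forward(ctx, output: Tensor, chunk_outputs: Tuple[Tensor, ...]) -> Tensor:
        ctx.chunk_outputs = chunk_outputs
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        # Replay backward through each stashed chunk, newest first. The chunks
        # may share parts of the autograd graph, so retain_graph=True keeps the
        # shared buffers alive until every chunk has been processed; this is
        # the value the diff above hardcodes.
        for chunk in reversed(ctx.chunk_outputs):
            torch.autograd.backward(chunk, grad_tensors=torch.ones_like(chunk), retain_graph=True)
        return grad_output, None


# Tiny usage example: two "chunks" whose graphs share an intermediate node,
# so backward must retain the graph between the per-chunk replays.
w = torch.randn(3, requires_grad=True)
h = w * w                    # shared intermediate; its backward saves inputs
chunks = (h * 1.0, h * 3.0)  # stand-ins for per-micro-batch outputs
gathered = torch.cat(chunks).detach().requires_grad_()  # stand-in for the gathered output
out = PassThroughBackward.apply(gathered.sum(), chunks)
out.backward()
print(w.grad)  # gradients accumulated from both replayed chunks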