Commit de1ae5b2 authored by Lawrence McAfee
Browse files

added flag to control deallocation of pipeline outputs

parent 2de7ae27
......@@ -681,6 +681,9 @@ def _add_distributed_args(parser):
help='Call torch.cuda.empty_cache() each iteration '
'(training and eval), to reduce fragmentation.'
'0=off, 1=moderate, 2=aggressive.')
group.add_argument('--deallocate-pipeline-outputs', action='store_true',
default=False, help='If set, pipeline output tensors '
'are deallocated during the forward pass.')
return parser
......
......@@ -42,14 +42,14 @@ def get_forward_backward_func():
forward_backward_func = forward_backward_no_pipelining
return forward_backward_func
def free_output_tensor(output_tensors):
def free_output_tensor(output_tensors, deallocate_pipeline_outputs):
'''Pseudo-free (i.e., set to scalar) the output tensor's '.data' field.
This method should be called right after the output tensor has been
sent to the next pipeline stage. At this point, the output tensor is
only useful for its '.grad_fn' field, and not its '.data'.
'''
if output_tensors is None:
if not deallocate_pipeline_outputs or output_tensors is None:
return
if isinstance(output_tensors, torch.Tensor):
output_tensors = [output_tensors]
......@@ -164,7 +164,11 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
# Backward pass.
if output_tensor_grad[0] is None:
output_tensor = optimizer.scale_loss(output_tensor[0])
custom_backward(output_tensor[0], output_tensor_grad[0])
if args.deallocate_pipeline_outputs:
custom_backward(output_tensor[0], output_tensor_grad[0])
else:
torch.autograd.backward(output_tensor[0],
grad_tensors=output_tensor_grad[0])
# Collect the grad of the input_tensor.
input_tensor_grad = [None]
......@@ -372,7 +376,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
output_tensor, recv_prev=recv_prev,
tensor_shape=tensor_shape,
timers=timers)
free_output_tensor(output_tensor)
free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
# Run 1F1B in steady state.
......@@ -437,7 +441,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next,
tensor_shape=tensor_shape, timers=timers)
free_output_tensor(output_tensor)
free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
......@@ -571,6 +575,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
stages.
Returns dictionary with losses if the last stage, empty dict otherwise."""
args = get_args()
timers = get_timers()
assert len(model) == 1
......@@ -612,7 +617,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
if not forward_only:
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
free_output_tensor(output_tensor)
free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
......@@ -641,7 +646,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
free_output_tensor(output_tensor)
free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
# Pop input_tensor and output_tensor from the start of the list for
# the backward pass.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment