Commit d10f81c5 authored by Lawrence McAfee

removed uses of args.deallocate_pipeline_outputs

parent 18846a0a
@@ -76,7 +76,7 @@ def get_forward_backward_func():
 # )
 # # <<<
 # <<<
-def free_output_tensor(out, deallocate_pipeline_outputs):
+def free_output_tensor(out):
     '''Pseudo-free (i.e., set to scalar) the output tensor's '.data' field.

     This method should be called right after the output tensor has been
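The helper's body is not part of this hunk; below is only a minimal sketch of the simplified `free_output_tensor(out)` signature, assuming (per the docstring) that pseudo-freeing means replacing `.data` with a one-element placeholder so the autograd graph node stays alive while the large activation storage is reclaimed.

```python
import torch

def free_output_tensor(out):
    '''Hypothetical sketch of the pseudo-free: keep the graph node, drop the
    large activation storage by pointing '.data' at a 1-element placeholder.'''
    if out is None:
        return
    assert isinstance(out, torch.Tensor), "expected a tensor"
    # Assumption: the placeholder keeps device/dtype so later code still sees
    # a consistent (if tiny) tensor.
    out.data = torch.empty((1,), device=out.device, dtype=out.dtype)
```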
@@ -216,14 +216,7 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
     # Backward pass.
     if output_tensor_grad[0] is None:
         output_tensor = optimizer.scale_loss(output_tensor[0])
-    if args.deallocate_pipeline_outputs:
-        # >>>
-        # pax(4, {"output_tensor": output_tensor})
-        # <<<
-        custom_backward(output_tensor[0], output_tensor_grad[0])
-    else:
-        torch.autograd.backward(output_tensor[0],
-                                grad_tensors=output_tensor_grad[0])
+    custom_backward(output_tensor[0], output_tensor_grad[0])

     # Collect the grad of the input_tensor.
     input_tensor_grad = [None]
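With `args.deallocate_pipeline_outputs` gone, the `torch.autograd.backward` branch is dropped and `custom_backward` is called unconditionally. The body of `custom_backward` is not shown in this commit; the toy below only illustrates the assumed motivation: PyTorch's Python-level `backward` validates that each gradient matches its output's shape, which no longer holds once the output has been pseudo-freed, so a path that hands the gradient to the autograd engine more directly is needed.

```python
import torch

# Toy, assumption-based illustration of why plain torch.autograd.backward
# cannot be used on a pseudo-freed output tensor.
x = torch.randn(4, 8, requires_grad=True)
out = x * 2.0                       # stand-in for a pipeline stage's output
grad = torch.ones(4, 8)             # stand-in for output_tensor_grad

out.data = torch.empty((1,), device=out.device, dtype=out.dtype)  # pseudo-free

try:
    torch.autograd.backward(out, grad_tensors=grad)
except RuntimeError as err:
    # Shape mismatch: the output now reports shape (1,) while the grad is (4, 8).
    print("plain backward rejected the pseudo-freed output:", err)
```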
@@ -431,8 +424,11 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                     output_tensor, recv_prev=recv_prev,
                     tensor_shape=tensor_shape,
                     timers=timers)
-        free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
         input_tensors[next_forward_model_chunk_id].append(input_tensor)
+        # >>>
+        pax({"output_tensor": output_tensor})
+        # <<<
+        free_output_tensor(output_tensor)

     # Run 1F1B in steady state.
     for k in range(num_microbatches_remaining):
@@ -496,7 +492,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                 output_tensor, input_tensor_grad,
                 recv_prev=recv_prev, recv_next=recv_next,
                 tensor_shape=tensor_shape, timers=timers)
-        free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
+        free_output_tensor(output_tensor)

         # Put input_tensor and output_tensor_grad in data structures in the
         # right location.
@@ -672,10 +668,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         if not forward_only:
             input_tensors.append(input_tensor)
             output_tensors.append(output_tensor)
-            # >>>
-            # pax(2, {"output_tensor": output_tensor})
-            # <<<
-            free_output_tensor(output_tensor[0], args.deallocate_pipeline_outputs)
+            free_output_tensor(output_tensor[0])

     # Before running 1F1B, need to receive first forward tensor.
     # If all microbatches are run in warmup / cooldown phase, then no need to
@@ -704,14 +697,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
             # Add input_tensor and output_tensor to end of list.
             input_tensors.append(input_tensor)
             output_tensors.append(output_tensor)
-            # >>>
-            # if torch.distributed.get_rank() == 3:
-            # pax({"output_tensor": output_tensor})
-            # <<<
-            # >>>
-            # free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
-            free_output_tensor(output_tensor[0], args.deallocate_pipeline_outputs)
-            # <<<
+            free_output_tensor(output_tensor[0])

             # Pop input_tensor and output_tensor from the start of the list for
             # the backward pass.
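The remaining call sites follow the same pattern: the output's graph handle is stashed (and, in the real schedule, the tensor is sent to the next stage) before its storage is reclaimed, so the later backward pass still has a graph to walk. A toy, assumption-based sketch of that ordering, reusing the hypothetical helper from above:

```python
import torch

def free_output_tensor(out):
    # Hypothetical pseudo-free, as sketched earlier.
    out.data = torch.empty((1,), device=out.device, dtype=out.dtype)

x = torch.randn(2, 3, requires_grad=True)
output_tensor = [x * 3.0]          # stand-in for a forward_step result list
output_tensors = [output_tensor]   # keep the graph handle for the backward pass
free_output_tensor(output_tensor[0])
print(output_tensor[0].shape)      # torch.Size([1]); grad_fn is still attached
```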