Commit 2f25c570 authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

working: interleaving; free_output_tensor() now handles none/tensor/list

parent 86da10e9
......@@ -42,8 +42,13 @@ def get_forward_backward_func():
forward_backward_func = forward_backward_no_pipelining
return forward_backward_func
def free_output_tensor(t):
t.data = torch.FloatTensor([0]).to(t.data)
def free_output_tensor(output_tensors):
if output_tensors is None:
return
if isinstance(output_tensors, torch.Tensor):
output_tensors = [output_tensors]
for output_tensor in output_tensors:
output_tensor.data = torch.FloatTensor([0]).to(output_tensor.data)
def custom_backward(output, grad_output):
......@@ -354,6 +359,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
output_tensor, recv_prev=recv_prev,
tensor_shape=tensor_shape,
timers=timers)
free_output_tensor(output_tensor)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
# Run 1F1B in steady state.
......@@ -418,6 +424,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next,
tensor_shape=tensor_shape, timers=timers)
free_output_tensor(output_tensor)
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
......@@ -590,9 +597,9 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
send_forward(output_tensor, send_tensor_shapes, timers=timers)
if not forward_only:
[ free_output_tensor(t) for t in output_tensor ]
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
free_output_tensor(output_tensor)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
......@@ -619,9 +626,9 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
timers=timers)
# Add input_tensor and output_tensor to end of list.
[ free_output_tensor(t) for t in output_tensor ]
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
free_output_tensor(output_tensor)
# Pop input_tensor and output_tensor from the start of the list for
# the backward pass.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment