Commit 2f25c570 authored by Lawrence McAfee

working: interleaving; free_output_tensor() now handles none/tensor/list

parent 86da10e9
@@ -42,8 +42,13 @@ def get_forward_backward_func():
         forward_backward_func = forward_backward_no_pipelining
     return forward_backward_func
 
-def free_output_tensor(t):
-    t.data = torch.FloatTensor([0]).to(t.data)
+def free_output_tensor(output_tensors):
+    if output_tensors is None:
+        return
+    if isinstance(output_tensors, torch.Tensor):
+        output_tensors = [output_tensors]
+    for output_tensor in output_tensors:
+        output_tensor.data = torch.FloatTensor([0]).to(output_tensor.data)
 
 def custom_backward(output, grad_output):
@@ -354,6 +359,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                     output_tensor, recv_prev=recv_prev,
                     tensor_shape=tensor_shape,
                     timers=timers)
+        free_output_tensor(output_tensor)
         input_tensors[next_forward_model_chunk_id].append(input_tensor)
 
     # Run 1F1B in steady state.
@@ -418,6 +424,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                     output_tensor, input_tensor_grad,
                     recv_prev=recv_prev, recv_next=recv_next,
                     tensor_shape=tensor_shape, timers=timers)
+        free_output_tensor(output_tensor)
 
         # Put input_tensor and output_tensor_grad in data structures in the
         # right location.
@@ -590,9 +597,9 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         send_forward(output_tensor, send_tensor_shapes, timers=timers)
 
         if not forward_only:
-            [ free_output_tensor(t) for t in output_tensor ]
             input_tensors.append(input_tensor)
             output_tensors.append(output_tensor)
+            free_output_tensor(output_tensor)
 
     # Before running 1F1B, need to receive first forward tensor.
     # If all microbatches are run in warmup / cooldown phase, then no need to
@@ -619,9 +626,9 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
                                        timers=timers)
 
             # Add input_tensor and output_tensor to end of list.
-            [ free_output_tensor(t) for t in output_tensor ]
             input_tensors.append(input_tensor)
             output_tensors.append(output_tensor)
+            free_output_tensor(output_tensor)
 
             # Pop input_tensor and output_tensor from the start of the list for
             # the backward pass.
...
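For reference, a minimal standalone sketch (not the Megatron-LM source itself; the driver code around the helper is hypothetical) of what the reworked free_output_tensor() does: the output's .data is swapped for a one-element placeholder on the same device and dtype, so the large activation buffer can be reclaimed as soon as the output has been sent downstream, while the Python tensor object, and the autograd graph hanging off it, stays alive for the later backward pass.

```python
import torch

# Sketch of the new semantics: None, a single tensor, and a list of
# tensors are all accepted; each tensor's storage is released by
# replacing .data with a one-element placeholder.
def free_output_tensor(output_tensors):
    if output_tensors is None:
        return                                  # nothing to free
    if isinstance(output_tensors, torch.Tensor):
        output_tensors = [output_tensors]       # normalize to a list
    for output_tensor in output_tensors:
        # .to(tensor) moves the placeholder to the original tensor's
        # device and dtype; the large activation buffer is dropped.
        output_tensor.data = torch.FloatTensor([0]).to(output_tensor.data)

x = torch.randn(8, 4096, requires_grad=True)
out = x * 2                      # stand-in for a pipeline stage's output
free_output_tensor(out)          # single tensor
free_output_tensor([out])        # list of tensors
free_output_tensor(None)         # no-op
print(out.shape)                 # torch.Size([1]): payload freed
print(out.grad_fn is not None)   # True: autograd graph still reachable
```

Because the freed output no longer has its original shape, a plain torch.autograd.backward(output, grad_output) would trip autograd's shape check; this is presumably why the schedules drive the backward pass through the custom_backward helper seen above rather than the stock entry point.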