Commit 0865c4dc authored by Lawrence McAfee

removed saving of input/output tensors for forward-only passes of pipeline schedules

parent 6a680986
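
In forward-only mode (inference or evaluation) these schedules never call `backward_step`, so there is no later consumer for the saved input/output activations; skipping the bookkeeping lets each microbatch's tensors be freed as soon as they have been sent downstream. A minimal sketch of that pattern, using hypothetical placeholder strings instead of real tensors and a toy loop instead of the actual schedule:

```python
def run_schedule(num_microbatches, forward_only):
    # Saved activations; only needed if a backward pass will run.
    input_tensors, output_tensors = [], []
    for i in range(num_microbatches):
        input_tensor = f"input_activation_{i}"    # stand-in for a received tensor
        output_tensor = f"output_activation_{i}"  # stand-in for forward_step(...)
        if not forward_only:
            # Backward consumes these later, so they must stay referenced.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
    # With forward_only=True both lists stay empty, so each microbatch's
    # activations can be garbage-collected right after being sent downstream.
    return len(input_tensors), len(output_tensors)

assert run_schedule(4, forward_only=True) == (0, 0)
assert run_schedule(4, forward_only=False) == (4, 4)
```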
@@ -194,6 +194,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
         model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
         mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
 
+        # forward step
         if mpu.is_pipeline_first_stage():
             if len(input_tensors[model_chunk_id]) == \
                     len(output_tensors[model_chunk_id]):
@@ -205,6 +206,11 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                                      input_tensor, losses_reduced)
         output_tensors[model_chunk_id].append(output_tensor)
 
+        # if forward-only, no need to save tensors for a backward pass
+        if forward_only:
+            input_tensors[model_chunk_id].pop()
+            output_tensors[model_chunk_id].pop()
+
         return output_tensor
 
     def backward_step_helper(microbatch_id):
@@ -383,8 +389,12 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
     num_microbatches_remaining = \
         num_microbatches - num_warmup_microbatches
 
-    input_tensors = []
-    output_tensors = []
+    # Input, output tensors only need to be saved when doing backward passes
+    input_tensors = None
+    output_tensors = None
+    if not forward_only:
+        input_tensors = []
+        output_tensors = []
     losses_reduced = []
 
     # Run warmup forward passes.
@@ -394,8 +404,9 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
                                      input_tensor, losses_reduced)
         p2p_communication.send_forward(output_tensor, timers=timers)
 
-        input_tensors.append(input_tensor)
-        output_tensors.append(output_tensor)
+        if not forward_only:
+            input_tensors.append(input_tensor)
+            output_tensors.append(output_tensor)
 
     # Before running 1F1B, need to receive first forward tensor.
     # If all microbatches are run in warmup / cooldown phase, then no need to
@@ -411,21 +422,23 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
                                      input_tensor, losses_reduced)
         if forward_only:
             p2p_communication.send_forward(output_tensor, timers=timers)
+
+            if not last_iteration:
+                input_tensor = p2p_communication.recv_forward(timers=timers)
+
         else:
             output_tensor_grad = \
                 p2p_communication.send_forward_recv_backward(output_tensor,
                                                              timers=timers)
 
-        # Add input_tensor and output_tensor to end of list, then pop from the
-        # start of the list for backward pass.
-        input_tensors.append(input_tensor)
-        output_tensors.append(output_tensor)
-
-        if forward_only:
-            if not last_iteration:
-                input_tensor = p2p_communication.recv_forward(timers=timers)
-        else:
-            input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
+            # Add input_tensor and output_tensor to end of list.
+            input_tensors.append(input_tensor)
+            output_tensors.append(output_tensor)
+
+            # Pop input_tensor and output_tensor from the start of the list for
+            # the backward pass.
+            input_tensor = input_tensors.pop(0)
+            output_tensor = output_tensors.pop(0)
 
             input_tensor_grad = \
                 backward_step(optimizer, input_tensor, output_tensor,
...
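
For context on the reorganized `else:` branch above: 1F1B retires microbatches in the same order their forwards ran, which is why a plain list used as a FIFO (append to the end, `pop(0)` from the front) is sufficient bookkeeping. A toy illustration of that queue discipline, assuming two warmup microbatches are already in flight (placeholder strings, not the actual schedule code):

```python
input_tensors = ["in_0", "in_1"]     # pretend warmup saved two microbatches
output_tensors = ["out_0", "out_1"]

def steady_state_step(input_tensor, output_tensor):
    """One 1F1B iteration: enqueue the newest microbatch, dequeue the oldest."""
    input_tensors.append(input_tensor)    # add to end of list
    output_tensors.append(output_tensor)
    # The gradient that just arrived belongs to the oldest in-flight
    # microbatch, so pop from the start of the list for its backward pass.
    return input_tensors.pop(0), output_tensors.pop(0)

assert steady_state_step("in_2", "out_2") == ("in_0", "out_0")
```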