Commit 8e922d5b authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Put in barriers in appropriate places to measure length of pipeline stall

parent dd889062
......@@ -95,6 +95,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
output_tensor_grads = [[] for _ in range(len(model))]
pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()
# Compute number of warmup and remaining microbatches.
num_model_chunks = len(model)
......@@ -108,8 +109,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
all_warmup_microbatches = True
else:
num_warmup_microbatches = \
(pipeline_parallel_size -
mpu.get_pipeline_model_parallel_rank() - 1) * 2
(pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = \
......@@ -272,6 +272,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
def forward_backward_pipelining(forward_step_func, data_iterator, model,
optimizer, timers, forward_only):
"""Run 1F1B schedule, with communication and warmup + cooldown microbatches as needed."""
timers = get_timers()
assert len(model) == 1
model = model[0]
......@@ -295,11 +297,22 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
input_tensor = recv_forward(timers)
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
# Barrier before first receive to measure forward stall.
if i == (num_warmup_microbatches - 1):
timers('forward-pipeline-stall').start()
torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
timers('forward-pipeline-stall').stop()
send_forward(output_tensor, timers)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Barrier before first receive to measure forward stall.
if num_warmup_microbatches == 0:
timers('forward-pipeline-stall').start()
torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
timers('forward-pipeline-stall').stop()
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
......
......@@ -354,6 +354,11 @@ def train_step(forward_step_func, data_iterator,
fp32_allreduce=args.fp32_allreduce)
timers('backward-params-all-reduce').stop()
# Barrier to measure backward stall.
timers('backward-pipeline-stall').start()
torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
timers('backward-pipeline-stall').stop()
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
......@@ -443,10 +448,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
if name in timers.timers:
timers_to_log.append(name)
add_to_logging('forward-compute')
add_to_logging('forward-pipeline-stall')
add_to_logging('forward-recv')
add_to_logging('forward-send')
add_to_logging('forward-backward-send-forward-backward-recv')
add_to_logging('backward-compute')
add_to_logging('backward-pipeline-stall')
add_to_logging('backward-recv')
add_to_logging('backward-send')
add_to_logging('backward-send-forward-recv')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment