Commit 8e922d5b authored by Deepak Narayanan

Put in barriers in appropriate places to measure length of pipeline stall

parent dd889062
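
The pattern this commit applies is to bracket a barrier over the pipeline-model-parallel group with a named timer: every rank enters the barrier, fast ranks block until the slowest rank arrives, so each rank's timer records how long that rank stalled waiting for the rest of the pipeline. A minimal standalone sketch of the idea, assuming a torch.distributed job launched with torchrun; the simulated work and output are illustrative and not part of Megatron:

import time

import torch.distributed as dist


def main():
    # gloo backend so the sketch also runs on CPU-only machines.
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()

    # Simulate uneven per-rank work, as pipeline stages have during warmup.
    time.sleep(0.1 * (rank + 1))

    # Time the wait at the barrier: fast ranks record a large stall,
    # the slowest rank records roughly zero.
    start = time.perf_counter()
    dist.barrier()
    stall = time.perf_counter() - start
    print(f"rank {rank}: pipeline-stall proxy = {stall:.3f}s")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Run with, say, torchrun --nproc_per_node=4 sketch.py: rank 3 (the slowest here) reports a near-zero stall, while rank 0 reports roughly 0.3s.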
@@ -95,6 +95,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
     output_tensor_grads = [[] for _ in range(len(model))]
 
     pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
+    pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()
 
     # Compute number of warmup and remaining microbatches.
     num_model_chunks = len(model)
@@ -108,8 +109,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
         all_warmup_microbatches = True
     else:
         num_warmup_microbatches = \
-            (pipeline_parallel_size -
-             mpu.get_pipeline_model_parallel_rank() - 1) * 2
+            (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
         num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
         num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
     num_microbatches_remaining = \
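
To make the refactored expression concrete: with pipeline_parallel_size = 4, pipeline_parallel_rank = 1, and num_model_chunks = 2, the formula gives (4 - 1 - 1) * 2 + (2 - 1) * 4 = 8 warmup microbatches before the min() cap, so later stages (higher ranks) warm up with fewer microbatches than earlier ones.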
@@ -272,6 +272,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
 
 def forward_backward_pipelining(forward_step_func, data_iterator, model,
                                 optimizer, timers, forward_only):
     """Run 1F1B schedule, with communication and warmup + cooldown microbatches as needed."""
+    timers = get_timers()
+
     assert len(model) == 1
     model = model[0]
@@ -295,11 +297,22 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
         input_tensor = recv_forward(timers)
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
+        # Barrier before first receive to measure forward stall.
+        if i == (num_warmup_microbatches - 1):
+            timers('forward-pipeline-stall').start()
+            torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
+            timers('forward-pipeline-stall').stop()
         send_forward(output_tensor, timers)
 
         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)
 
+    # Barrier before first receive to measure forward stall.
+    if num_warmup_microbatches == 0:
+        timers('forward-pipeline-stall').start()
+        torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
+        timers('forward-pipeline-stall').stop()
+
     # Before running 1F1B, need to receive first forward tensor.
     # If all microbatches are run in warmup / cooldown phase, then no need to
     # receive this tensor here.
@@ -354,6 +354,11 @@ def train_step(forward_step_func, data_iterator,
                                        fp32_allreduce=args.fp32_allreduce)
     timers('backward-params-all-reduce').stop()
 
+    # Barrier to measure backward stall.
+    timers('backward-pipeline-stall').start()
+    torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
+    timers('backward-pipeline-stall').stop()
+
     # All-reduce word_embeddings' grad across first and last stages to ensure
     # that word_embeddings parameters stay in sync.
     # This should only run for models that support pipelined model parallelism
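
This barrier sits after the data-parallel gradient all-reduce finishes, so 'backward-pipeline-stall' captures how long each pipeline stage idles at the end of the backward pass waiting for the other stages, before the word-embeddings gradient all-reduce between the first and last stages runs.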
@@ -443,10 +448,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         if name in timers.timers:
             timers_to_log.append(name)
     add_to_logging('forward-compute')
+    add_to_logging('forward-pipeline-stall')
     add_to_logging('forward-recv')
     add_to_logging('forward-send')
     add_to_logging('forward-backward-send-forward-backward-recv')
     add_to_logging('backward-compute')
+    add_to_logging('backward-pipeline-stall')
     add_to_logging('backward-recv')
     add_to_logging('backward-send')
     add_to_logging('backward-send-forward-recv')
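
For reference, the timers(...) calls above follow a name-to-timer registry pattern. The sketch below is an illustrative stand-in that mirrors only the behavior the diff relies on (timers(name) returning an object with start()/stop(), plus the .timers dict that the add_to_logging guard checks); it is not Megatron's actual Timers implementation:

import time


class _Timer:
    # Accumulates wall-clock time across start()/stop() pairs.
    def __init__(self):
        self.elapsed_ = 0.0
        self._start = None

    def start(self):
        self._start = time.perf_counter()

    def stop(self):
        self.elapsed_ += time.perf_counter() - self._start
        self._start = None


class Timers:
    # Registry of named timers; timers(name) creates a timer on first use,
    # which is why `name in timers.timers` is a safe has-been-used check.
    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        return self.timers.setdefault(name, _Timer())


timers = Timers()
timers('forward-pipeline-stall').start()
time.sleep(0.01)  # stand-in for torch.distributed.barrier(...)
timers('forward-pipeline-stall').stop()
print('forward-pipeline-stall: '
      f"{timers('forward-pipeline-stall').elapsed_ * 1000:.1f} ms")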