Commit 08ddde0c authored by Deepak Narayanan

Fix deadlock when get_num_microbatches() < pipeline-parallel size (don't try to measure pipeline stall)
parent 854c3409
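
Why this deadlocks: torch.distributed.barrier() is a collective over the pipeline group, so it returns only once every rank in the group has entered it. When get_num_microbatches() is smaller than the pipeline-parallel world size, a downstream stage can still be blocked in a point-to-point receive, waiting for an activation from an upstream stage that has already entered the barrier; neither side can make progress and training hangs. Below is a minimal sketch of the guarded measurement, assuming Megatron-LM-style helpers (timers, mpu, get_num_microbatches); the wrapper function itself is hypothetical, not part of the codebase:

import torch

def barrier_and_time_stall(timers, mpu, num_microbatches):
    # Hypothetical helper illustrating the fix below.
    # torch.distributed.barrier() completes only after every rank in the
    # group calls it. With fewer microbatches than pipeline stages, some
    # ranks are still blocked in point-to-point receives and never reach
    # the barrier, so skip the measurement entirely in that case.
    if num_microbatches < mpu.get_pipeline_model_parallel_world_size():
        return
    timers('forward-pipeline-stall').start()
    torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
    timers('forward-pipeline-stall').stop()
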
@@ -362,6 +362,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
     num_microbatches_remaining = \
         num_microbatches - num_warmup_microbatches
 
+    # Measure pipeline stall only if there are enough microbatches
+    # to have every worker in a warmup and steady state phase.
+    measure_pipeline_stall = get_num_microbatches() >= \
+        mpu.get_pipeline_model_parallel_world_size()
+
     input_tensors = []
     output_tensors = []
     losses_reduced = []
@@ -372,7 +377,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
         # Barrier before first receive to measure forward stall.
-        if i == (num_warmup_microbatches - 1):
+        if i == (num_warmup_microbatches - 1) and measure_pipeline_stall:
             timers('forward-pipeline-stall').start()
             torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
             timers('forward-pipeline-stall').stop()
@@ -382,7 +387,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         output_tensors.append(output_tensor)
 
     # Barrier before first receive to measure forward stall.
-    if num_warmup_microbatches == 0:
+    if num_warmup_microbatches == 0 and measure_pipeline_stall:
         timers('forward-pipeline-stall').start()
         torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
         timers('forward-pipeline-stall').stop()
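
Note that both barrier sites are gated by the same measure_pipeline_stall flag, which every pipeline rank computes identically from get_num_microbatches() and the pipeline-parallel world size. The barrier is therefore entered by all ranks in the group or by none of them; a rank-dependent condition here would reintroduce the hang.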