Commit 08ddde0c authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Fix deadlock when get_num_microbatches() < pipeline-parallel size (don't try...

Fix deadlock when get_num_microbatches() < pipeline-parallel size (don't try to measure pipeline stall)
parent 854c3409
...@@ -362,6 +362,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite ...@@ -362,6 +362,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
num_microbatches_remaining = \ num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches num_microbatches - num_warmup_microbatches
# Measure pipeline stall only if there are enough microbatches
# to have every worker in a warmup and steady state phase.
measure_pipeline_stall = get_num_microbatches() >= \
mpu.get_pipeline_model_parallel_world_size()
input_tensors = [] input_tensors = []
output_tensors = [] output_tensors = []
losses_reduced = [] losses_reduced = []
...@@ -372,7 +377,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite ...@@ -372,7 +377,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
output_tensor = forward_step(forward_step_func, data_iterator, model, output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced) input_tensor, losses_reduced)
# Barrier before first receive to measure forward stall. # Barrier before first receive to measure forward stall.
if i == (num_warmup_microbatches - 1): if i == (num_warmup_microbatches - 1) and measure_pipeline_stall:
timers('forward-pipeline-stall').start() timers('forward-pipeline-stall').start()
torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group()) torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
timers('forward-pipeline-stall').stop() timers('forward-pipeline-stall').stop()
...@@ -382,7 +387,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite ...@@ -382,7 +387,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
output_tensors.append(output_tensor) output_tensors.append(output_tensor)
# Barrier before first receive to measure forward stall. # Barrier before first receive to measure forward stall.
if num_warmup_microbatches == 0: if num_warmup_microbatches == 0 and measure_pipeline_stall:
timers('forward-pipeline-stall').start() timers('forward-pipeline-stall').start()
torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group()) torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
timers('forward-pipeline-stall').stop() timers('forward-pipeline-stall').stop()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment