Commit 3cbf7547 authored by Jared Casper

Merge branch 'deadlock_fix' into 'main'

Fix deadlock when get_num_microbatches() < pipeline-parallel size

See merge request ADLR/megatron-lm!243
parents 854c3409 08ddde0c
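
Background on the hang (illustration only; the sketch below is not part of the change): in the 1F1B schedule without interleaving, rank r runs min(pipeline_size - rank - 1, num_microbatches) warmup forwards, and the stall-measuring barrier fires after a rank's last warmup forward but before that forward's activation is sent downstream (assumptions inferred from the hunks below; the send itself is outside the shown context). With fewer microbatches than pipeline stages, an upstream rank can therefore enter the barrier while still holding an activation its successor needs to finish warmup; the successor blocks on a receive and never joins the barrier, so every rank already waiting in the barrier hangs with it. A minimal Python sketch of that ordering, with illustrative names and a deliberately simplified send/receive model:

# Illustrative simulation only -- not Megatron-LM code. Assumptions: rank r
# runs min(pipeline_size - rank - 1, num_microbatches) warmup forwards (the
# 1F1B schedule without interleaving), each warmup forward on rank r > 0
# consumes one activation from rank r - 1, and the activation of the last
# warmup forward is sent downstream only after the stall-measuring barrier.

def warmup_count(rank, pipeline_size, num_microbatches):
    return min(pipeline_size - rank - 1, num_microbatches)

def ranks_reaching_barrier(pipeline_size, num_microbatches):
    """Which ranks can finish warmup and enter the stall-measuring barrier?"""
    reaching = set()
    sent_downstream = 0  # activations the previous rank sent before blocking
    for rank in range(pipeline_size):
        warmup = warmup_count(rank, pipeline_size, num_microbatches)
        received = warmup if rank == 0 else min(warmup, sent_downstream)
        if received >= warmup:
            # Finishes warmup and enters the barrier, still holding the
            # activation of its last warmup forward.
            reaching.add(rank)
            sent_downstream = max(warmup - 1, 0)
        else:
            # Starves on a receive and never reaches the barrier.
            sent_downstream = received
    return reaching

for num_microbatches in (2, 8):
    stuck = sorted(set(range(4)) - ranks_reaching_barrier(4, num_microbatches))
    status = f"deadlock, rank(s) {stuck} never reach the barrier" if stuck else \
        "all ranks reach the barrier"
    print(f"pipeline_size=4, num_microbatches={num_microbatches}: {status}")

Because measure_pipeline_stall depends only on get_num_microbatches() and the pipeline-parallel world size, every rank evaluates it identically, so either all ranks enter the barrier or none do; per the comment in the diff, the guard is chosen so that measurement only happens when every worker sees both a warmup and a steady-state phase.
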
@@ -362,6 +362,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
     num_microbatches_remaining = \
         num_microbatches - num_warmup_microbatches
 
+    # Measure pipeline stall only if there are enough microbatches
+    # to have every worker in a warmup and steady state phase.
+    measure_pipeline_stall = get_num_microbatches() >= \
+        mpu.get_pipeline_model_parallel_world_size()
+
     input_tensors = []
     output_tensors = []
     losses_reduced = []
@@ -372,7 +377,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
         # Barrier before first receive to measure forward stall.
-        if i == (num_warmup_microbatches - 1):
+        if i == (num_warmup_microbatches - 1) and measure_pipeline_stall:
             timers('forward-pipeline-stall').start()
             torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
             timers('forward-pipeline-stall').stop()
@@ -382,7 +387,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         output_tensors.append(output_tensor)
 
     # Barrier before first receive to measure forward stall.
-    if num_warmup_microbatches == 0:
+    if num_warmup_microbatches == 0 and measure_pipeline_stall:
        timers('forward-pipeline-stall').start()
        torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
        timers('forward-pipeline-stall').stop()
...