Commit 90623142 authored by Jared Casper

Merge branch 'interleaved_schedule_assertion_fix' into 'main'

Fix the interleaved schedule assertion: move the check from parse_args to train_step, and assert that the number of microbatches, rather than the global batch size, is divisible by the pipeline-parallel size.

See merge request ADLR/megatron-lm!237
parents 9dc111cc b7067cbd
@@ -123,9 +123,6 @@ def parse_args(extra_args_provider=None, defaults={},
         args.virtual_pipeline_model_parallel_size = \
             (args.num_layers // args.pipeline_model_parallel_size) // \
             args.num_layers_per_virtual_pipeline_stage
-        assert args.global_batch_size % args.pipeline_model_parallel_size == 0, \
-            'global batch size is not divisible by pipeline parallel size when ' \
-            'using interleaved schedule'
     else:
         args.virtual_pipeline_model_parallel_size = None
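The removed assertion tested the wrong quantity: the interleaved schedule constrains the number of microbatches per step, not the global batch size directly. A minimal sketch of the standard Megatron-LM relation between the two quantities (this helper is hypothetical, written only to illustrate the arithmetic, and is not part of the merge request):

    # Hypothetical helper illustrating how the microbatch count is derived.
    def num_microbatches(global_batch_size, micro_batch_size, data_parallel_size):
        # Each training step consumes global_batch_size samples, split across
        # data_parallel_size replicas into microbatches of micro_batch_size
        # samples each.
        samples_per_microbatch_step = micro_batch_size * data_parallel_size
        assert global_batch_size % samples_per_microbatch_step == 0
        return global_batch_size // samples_per_microbatch_step

Because the microbatch count also depends on the micro batch size and the data-parallel size, divisibility of the global batch size alone does not guarantee divisibility of the microbatch count.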
@@ -339,6 +339,9 @@ def train_step(forward_step_func, data_iterator,
     if mpu.get_pipeline_model_parallel_world_size() > 1:
         if args.virtual_pipeline_model_parallel_size is not None:
             forward_backward_func = forward_backward_pipelining_with_interleaving
+            assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
+                'number of microbatches is not divisible by pipeline-parallel ' \
+                'size when using interleaved schedule'
         else:
             forward_backward_func = forward_backward_pipelining_without_interleaving
     else:
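A worked example with hypothetical sizes shows why the relocated check is stronger than the old one: a global batch size can be divisible by the pipeline-parallel size while the resulting microbatch count is not.

    # Hypothetical sizes: the old check passes but the new one correctly fails.
    global_batch_size = 48
    micro_batch_size = 4
    data_parallel_size = 1
    pipeline_model_parallel_size = 8

    n_microbatches = global_batch_size // (micro_batch_size * data_parallel_size)  # 12

    # Old check (removed from parse_args): passes, since 48 % 8 == 0.
    print(global_batch_size % pipeline_model_parallel_size == 0)   # True
    # New check (added to train_step): fails, since 12 % 8 != 0.
    print(n_microbatches % pipeline_model_parallel_size == 0)      # False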