Commit b7067cbd authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Fix interleaved schedule assertion

parent 9dc111cc
......@@ -123,9 +123,6 @@ def parse_args(extra_args_provider=None, defaults={},
args.virtual_pipeline_model_parallel_size = \
(args.num_layers // args.pipeline_model_parallel_size) // \
args.num_layers_per_virtual_pipeline_stage
assert args.global_batch_size % args.pipeline_model_parallel_size == 0, \
'global batch size is not divisible by pipeline parallel size when ' \
'using interleaved schedule'
else:
args.virtual_pipeline_model_parallel_size = None
......
......@@ -339,6 +339,9 @@ def train_step(forward_step_func, data_iterator,
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
'number of microbatches is not divisible by pipeline-parallel ' \
'size when using interleaved schedule'
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment