"git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "e53d3cb420868be4c9d281a41f654bba362eca58"
Commit 182841f7 authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Make sure pipeline-model-parallel size is greater than 2 for interleaved schedule

parent a6e00d97
...@@ -116,6 +116,9 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -116,6 +116,9 @@ def parse_args(extra_args_provider=None, defaults={},
args.global_batch_size), flush=True) args.global_batch_size), flush=True)
assert args.global_batch_size > 0 assert args.global_batch_size > 0
if args.num_layers_per_virtual_pipeline_stage is not None: if args.num_layers_per_virtual_pipeline_stage is not None:
assert args.pipeline_model_parallel_size > 2, \
'pipeline-model-parallel size should be greater than 2 with ' \
'interleaved schedule'
assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \ assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
'number of layers is not divisible by number of layers per virtual ' \ 'number of layers is not divisible by number of layers per virtual ' \
'pipeline stage' 'pipeline stage'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment