Commit dcef9069 authored by Deepak Narayanan

Change argument to control the number of model chunks in a stage

parent 5489bda9
@@ -116,10 +116,18 @@ def parse_args(extra_args_provider=None, defaults={},
             print('setting global batch size to {}'.format(
                 args.global_batch_size), flush=True)
     assert args.global_batch_size > 0
-    if args.virtual_pipeline_model_parallel_size is not None:
+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
+            'number of layers is not divisible by number of layers per virtual ' \
+            'pipeline stage'
+        args.virtual_pipeline_model_parallel_size = \
+            (args.num_layers // args.pipeline_model_parallel_size) // \
+            args.num_layers_per_virtual_pipeline_stage
         assert args.global_batch_size % args.pipeline_model_parallel_size == 0, \
-            'global batch size is not divisible by pipeline parallel size when '\
+            'global batch size is not divisible by pipeline parallel size when ' \
             'using interleaved schedule'
+    else:
+        args.virtual_pipeline_model_parallel_size = None

     # Parameters dtype.
     args.params_dtype = torch.float
@@ -561,8 +569,8 @@ def _add_distributed_args(parser):
     group.add_argument('--model-parallel-size', type=int, default=None,
                        help='Old model parallel argument, do not use. Use '
                        '--tensor-model-parallel-size instead.')
-    group.add_argument('--virtual-pipeline-model-parallel-size', type=int, default=None,
-                       help='Number of virtual pipeline stages in physical stage.')
+    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
+                       help='Number of layers per virtual pipeline stage')
     group.add_argument('--distributed-backend', default='nccl',
                        choices=['nccl', 'gloo'],
                        help='Which backend to use for distributed training.')
...
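For context, a minimal worked example of how the new flag maps onto the derived virtual_pipeline_model_parallel_size, using hypothetical values (24 layers, pipeline-model-parallel size 4, 2 layers per virtual pipeline stage) rather than anything taken from the commit:

# Hypothetical example values; the arithmetic mirrors the logic added in parse_args above.
num_layers = 24                              # --num-layers
pipeline_model_parallel_size = 4             # --pipeline-model-parallel-size
num_layers_per_virtual_pipeline_stage = 2    # --num-layers-per-virtual-pipeline-stage (new flag)

# Each pipeline stage holds num_layers // pipeline_model_parallel_size = 6 layers;
# splitting them into chunks of 2 layers gives 3 model chunks (virtual stages) per stage.
assert num_layers % num_layers_per_virtual_pipeline_stage == 0
virtual_pipeline_model_parallel_size = \
    (num_layers // pipeline_model_parallel_size) // \
    num_layers_per_virtual_pipeline_stage
print(virtual_pipeline_model_parallel_size)  # 3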