Commit 804ed2e6 authored by Lawrence McAfee
Browse files

working with interleaving

parent a06af061
...@@ -147,7 +147,7 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -147,7 +147,7 @@ def parse_args(extra_args_provider=None, defaults={},
# args.num_layers_per_virtual_pipeline_stage # args.num_layers_per_virtual_pipeline_stage
# <<< # <<<
args.virtual_pipeline_model_parallel_size = \ args.virtual_pipeline_model_parallel_size = \
(args.num_layers // transformer_pipeline_size) // \ (args.num_layers // args.transformer_pipeline_model_parallel_size) // \
args.num_layers_per_virtual_pipeline_stage args.num_layers_per_virtual_pipeline_stage
# >>> # >>>
# from lutil import pax # from lutil import pax
......
...@@ -350,13 +350,13 @@ def get_num_layers(args, is_encoder_and_decoder_model): ...@@ -350,13 +350,13 @@ def get_num_layers(args, is_encoder_and_decoder_model):
# get_pipeline_model_parallel_world_size() # get_pipeline_model_parallel_world_size()
# ) # )
# <<< # <<<
assert args.num_layers % transformer_pipeline_size == 0, \ assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
'num_layers must be divisible by transformer_pipeline_size' 'num_layers must be divisible by transformer_pipeline_model_parallel_size'
num_layers = ( num_layers = (
0 0
if args.standalone_embed_stage if args.standalone_embed_stage
and get_pipeline_model_parallel_rank() == 0 else and get_pipeline_model_parallel_rank() == 0 else
args.num_layers // transformer_pipeline_size args.num_layers // args.transformer_pipeline_model_parallel_size
) )
else: else:
num_layers = args.num_layers num_layers = args.num_layers
......
...@@ -40,9 +40,26 @@ def get_forward_backward_func(): ...@@ -40,9 +40,26 @@ def get_forward_backward_func():
# "pipeline size" : args.pipeline_model_parallel_size, # "pipeline size" : args.pipeline_model_parallel_size,
# }) # })
# <<< # <<<
assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \ # >>>
'number of microbatches is not divisible by pipeline-parallel ' \ # assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
'size when using interleaved schedule' # 'number of microbatches is not divisible by pipeline-parallel ' \
# 'size when using interleaved schedule'
# assert get_num_microbatches() % \
# args.transformer_pipeline_model_parallel_size == 0, \
# 'number of microbatches (%d) is not divisible by transformer-' \
# 'pipeline-model-parallel-size (%d) when using interleaved ' \
# 'schedule' % (
# get_num_microbatches(),
# args.transformer_pipeline_model_parallel_size,
# )
assert get_num_microbatches() % \
args.pipeline_model_parallel_size == 0, \
'number of microbatches (%d) is not divisible by pipeline-' \
'model-parallel-size (%d) when using interleaved schedule' % (
get_num_microbatches(),
args.pipeline_model_parallel_size,
)
# <<<
else: else:
forward_backward_func = forward_backward_pipelining_without_interleaving forward_backward_func = forward_backward_pipelining_without_interleaving
else: else:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment