Commit 10c6ad06 authored by Jared Casper

Merge branch 'lmcafee/embed-standalone' into 'main'

Standalone embedding stage

See merge request ADLR/megatron-lm!385
parents e724785f 2fadaa50
@@ -66,6 +66,11 @@ def parse_args(extra_args_provider=None, defaults={},
     args.pipeline_model_parallel_size = min(
         args.pipeline_model_parallel_size,
         (args.world_size // args.tensor_model_parallel_size))
+    args.transformer_pipeline_model_parallel_size = (
+        args.pipeline_model_parallel_size - 1
+        if args.standalone_embedding_stage else
+        args.pipeline_model_parallel_size
+    )

     # Checks.
     model_parallel_size = args.pipeline_model_parallel_size * \
                           args.tensor_model_parallel_size
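
For intuition, a minimal standalone sketch (illustrative only, not part of this commit; the helper name and example numbers are made up) of what the new derived argument evaluates to: one pipeline stage is reserved for the input embedding, and the remaining stages hold transformer layers.

from types import SimpleNamespace

def transformer_pp_size(args):
    # Mirrors the derivation in the hunk above.
    return (args.pipeline_model_parallel_size - 1
            if args.standalone_embedding_stage
            else args.pipeline_model_parallel_size)

args = SimpleNamespace(pipeline_model_parallel_size=4,
                       standalone_embedding_stage=True)
assert transformer_pp_size(args) == 3   # 1 embedding stage + 3 transformer stages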
@@ -137,7 +142,7 @@ def parse_args(extra_args_provider=None, defaults={},
             'number of layers is not divisible by number of layers per virtual ' \
             'pipeline stage'
         args.virtual_pipeline_model_parallel_size = \
-            (args.num_layers // args.pipeline_model_parallel_size) // \
+            (args.num_layers // args.transformer_pipeline_model_parallel_size) // \
             args.num_layers_per_virtual_pipeline_stage
     else:
         args.virtual_pipeline_model_parallel_size = None
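
As a worked example (illustrative numbers, not from the commit): with 24 layers, 4 pipeline stages plus a standalone embedding stage (so 3 transformer stages), and 2 layers per virtual stage:

num_layers = 24
transformer_pipeline_model_parallel_size = 3   # 4 pipeline stages - 1 embedding stage
num_layers_per_virtual_pipeline_stage = 2

virtual_pipeline_model_parallel_size = \
    (num_layers // transformer_pipeline_model_parallel_size) // \
    num_layers_per_virtual_pipeline_stage
assert virtual_pipeline_model_parallel_size == 4   # (24 // 3) // 2

With the old divisor (pipeline_model_parallel_size = 4) the result would be (24 // 4) // 2 = 3, which no longer matches the 8 layers each transformer stage actually holds once one stage is embedding-only.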
@@ -700,6 +705,11 @@ def _add_distributed_args(parser):
                        help='Call torch.cuda.empty_cache() each iteration '
                        '(training and eval), to reduce fragmentation.'
                        '0=off, 1=moderate, 2=aggressive.')
+    group.add_argument('--standalone-embedding-stage', action='store_true',
+                       default=False, help='If set, *input* embedding layer '
+                       'is placed on its own pipeline stage, without any '
+                       'transformer layers. (For T5, this flag currently only '
+                       'affects the encoder embedding.)')

     return parser
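Since the new flag is a plain store_true argument, its behaviour can be sanity-checked with argparse alone (a minimal sketch, independent of Megatron's full parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--standalone-embedding-stage', action='store_true',
                    default=False)

# Dashes in the flag become underscores in the parsed namespace.
args = parser.parse_args(['--standalone-embedding-stage'])
assert args.standalone_embedding_stage is True
assert parser.parse_args([]).standalone_embedding_stage is False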
@@ -579,6 +579,32 @@ class ParallelTransformerLayer(MegatronModule):
         return output


+class NoopTransformerLayer(MegatronModule):
+    """A single 'no-op' transformer layer.
+
+    The sole purpose of this layer is for when a standalone embedding layer
+    is used (i.e., args.standalone_embedding_stage == True). In this case,
+    zero transformer layers are assigned when pipeline rank == 0. Additionally,
+    when virtual pipeline rank >= 1, zero total model parameters are created
+    (virtual rank 0 contains the input embedding). This results in the model's
+    input and output tensors being the same, which causes an error when
+    performing certain memory optimizations on the output tensor (e.g.,
+    deallocating it). Thus, this layer disconnects the input from the output
+    via a clone. Since ranks containing a no-op layer are generally under-
+    utilized (both compute and memory), there is no concern about performance
+    degradation.
+    """
+
+    def __init__(self, layer_number):
+        super().__init__()
+        self.layer_number = layer_number
+
+    def forward(self, hidden_states, attention_mask,
+                encoder_output=None, enc_dec_attn_mask=None,
+                inference_params=None):
+        return hidden_states.clone()
+
+
 class ParallelTransformer(MegatronModule):
     """Transformer class."""
@@ -649,8 +675,20 @@ class ParallelTransformer(MegatronModule):
             else:
                 offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

-        self.layers = torch.nn.ModuleList(
-            [build_layer(i + 1 + offset) for i in range(self.num_layers)])
+        if self.num_layers == 0:
+            # When a standalone embedding stage is used (e.g.,
+            # args.standalone_embedding_stage == True), virtual pipeline ranks
+            # on pipeline rank 0 will have zero transformer layers assigned to
+            # them. This results in the model's input and output tensors being
+            # the same, which will cause failure for certain output tensor
+            # optimizations (e.g., pipeline output deallocation). To remedy
+            # this, we assign a 'no-op' layer on these ranks, which will
+            # disconnect the input tensor from the output tensor.
+            self.num_layers = 1
+            self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
+        else:
+            self.layers = torch.nn.ModuleList(
+                [build_layer(i + 1 + offset) for i in range(self.num_layers)])

         if self.post_process:
             # Final layer norm before output.
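
A minimal sketch (plain PyTorch with a stand-in class, not Megatron code) of what an embedding-only stage ends up holding after this change, and why the clone in NoopTransformerLayer matters for output-tensor deallocation:

import torch

class NoopLayer(torch.nn.Module):           # stand-in for NoopTransformerLayer
    def forward(self, hidden_states):
        return hidden_states.clone()

# An embedding-only stage holds exactly one no-op layer instead of an
# empty ModuleList.
layers = torch.nn.ModuleList([NoopLayer()])

hidden_states = torch.randn(4, 2, 8)        # toy [sequence, batch, hidden] input
output = hidden_states
for layer in layers:
    output = layer(output)

assert torch.equal(output, hidden_states)                 # same values
assert output.data_ptr() != hidden_states.data_ptr()      # distinct storage, so the
# schedule can deallocate 'output' without clobbering the stage's input.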
@@ -786,5 +824,5 @@ class ParallelTransformer(MegatronModule):
             output = self.final_layernorm(hidden_states)
         else:
             output = hidden_states

         return output
@@ -323,20 +323,44 @@ def get_num_layers(args, is_encoder_and_decoder_model):
     if get_pipeline_model_parallel_world_size() > 1:
         if is_encoder_and_decoder_model:
             assert args.pipeline_model_parallel_split_rank is not None
-            num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
-            num_ranks_in_decoder = get_pipeline_model_parallel_world_size() - num_ranks_in_encoder
+
+            # When a standalone embedding stage is used, a rank is taken from
+            # the encoder's ranks, to be used for the encoder's embedding
+            # layer. This way, the rank referenced by the 'split rank' remains
+            # the same whether or not a standalone embedding stage is used.
+            num_ranks_in_encoder = (
+                args.pipeline_model_parallel_split_rank - 1
+                if args.standalone_embedding_stage else
+                args.pipeline_model_parallel_split_rank
+            )
+            num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
             assert args.num_layers % num_ranks_in_encoder == 0, \
-                'num_layers must be divisible by number of ranks given to encoder'
+                'num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.num_layers, num_ranks_in_encoder)
             assert args.num_layers % num_ranks_in_decoder == 0, \
-                'num_layers must be divisible by number of ranks given to decoder'
+                'num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.num_layers, num_ranks_in_decoder)
             if is_pipeline_stage_before_split():
-                num_layers = args.num_layers // num_ranks_in_encoder
+                num_layers = (
+                    0
+                    if args.standalone_embedding_stage
+                    and get_pipeline_model_parallel_rank() == 0 else
+                    args.num_layers // num_ranks_in_encoder
+                )
             else:
                 num_layers = args.num_layers // num_ranks_in_decoder
         else:
-            assert args.num_layers % get_pipeline_model_parallel_world_size() == 0, \
-                'num_layers must be divisible by pipeline_model_parallel_size'
-            num_layers = args.num_layers // get_pipeline_model_parallel_world_size()
+            assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
+                'num_layers must be divisible by transformer_pipeline_model_parallel_size'
+
+            # When a standalone embedding stage is used, all transformer layers
+            # are divided among pipeline rank >= 1, while on pipeline rank 0,
+            # ranks either contain the input embedding layer (virtual pp rank 0),
+            # or no layers at all (virtual pp rank >= 1).
+            num_layers = (
+                0
+                if args.standalone_embedding_stage
+                and get_pipeline_model_parallel_rank() == 0 else
+                args.num_layers // args.transformer_pipeline_model_parallel_size
+            )
     else:
         num_layers = args.num_layers
     return num_layers
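
To see how the layer counts fall out, here is a simplified, self-contained sketch of the non-encoder-decoder branch above (the helper and numbers are illustrative, not Megatron code):

def layers_on_rank(num_layers, pipeline_size, rank, standalone_embedding_stage):
    # Mirrors the GPT-style branch: rank 0 holds only the input embedding
    # when the standalone embedding stage is enabled.
    transformer_pp_size = (pipeline_size - 1 if standalone_embedding_stage
                           else pipeline_size)
    if standalone_embedding_stage and rank == 0:
        return 0
    assert num_layers % transformer_pp_size == 0
    return num_layers // transformer_pp_size

# 24 layers over 4 pipeline ranks with a standalone embedding stage:
print([layers_on_rank(24, 4, r, True) for r in range(4)])    # [0, 8, 8, 8]
# Same model without the standalone embedding stage:
print([layers_on_rank(24, 4, r, False) for r in range(4)])   # [6, 6, 6, 6]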
@@ -34,9 +34,13 @@ def get_forward_backward_func():
     if mpu.get_pipeline_model_parallel_world_size() > 1:
         if args.virtual_pipeline_model_parallel_size is not None:
             forward_backward_func = forward_backward_pipelining_with_interleaving
-            assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
-                'number of microbatches is not divisible by pipeline-parallel ' \
-                'size when using interleaved schedule'
+            assert get_num_microbatches() % \
+                args.pipeline_model_parallel_size == 0, \
+                'number of microbatches (%d) is not divisible by pipeline-' \
+                'model-parallel-size (%d) when using interleaved schedule' % (
+                    get_num_microbatches(),
+                    args.pipeline_model_parallel_size,
+                )
         else:
             forward_backward_func = forward_backward_pipelining_without_interleaving
     else:
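For reference, the interleaved schedule's constraint amounts to a simple divisibility check (illustrative numbers, not from the commit):

num_microbatches = 8
pipeline_model_parallel_size = 4
assert num_microbatches % pipeline_model_parallel_size == 0   # OK: 8 % 4 == 0

# With pipeline_model_parallel_size = 3 the assertion in the hunk above would
# fire, and the reworded message now reports both values (8 and 3).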
@@ -949,7 +949,6 @@ def build_train_valid_test_data_iterators(
     args.do_valid = flags[1].item()
     args.do_test = flags[2].item()
-
     # Build iterators.
     dl_type = args.dataloader_type
     assert dl_type in ['single', 'cyclic']