Commit 46c74b4c authored by Deepak Narayanan

Support for pipeline parallelism in T5 model

- Accumulate encoder hidden state gradient to handle skip connection
- Correctly compute the number of layers in encoder / decoder for T5 model
- Ensure that weights are initialized the same way in embeddings
- Synchronize embedding gradients across encoder and decoder for T5 model
- Support for checkpoint loading and saving
parent 6a680986
@@ -80,6 +80,12 @@ def parse_args(extra_args_provider=None, defaults={},
              args.world_size, args.data_parallel_size,
              args.tensor_model_parallel_size,
              args.pipeline_model_parallel_size), flush=True)
if args.pipeline_model_parallel_size > 1:
if args.pipeline_model_parallel_split_rank is not None:
assert args.pipeline_model_parallel_split_rank < \
args.pipeline_model_parallel_size, 'split rank needs'\
' to be less than pipeline model parallel size ({})'.format(
args.pipeline_model_parallel_size)
    # Deprecated arguments
    assert args.batch_size is None, '--batch-size argument is no longer ' \
@@ -567,6 +573,9 @@ def _add_distributed_args(parser):
                       help='Degree of tensor model parallelism.')
    group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
                       help='Degree of pipeline model parallelism.')
group.add_argument('--pipeline-model-parallel-split-rank',
type=int, default=None,
help='Rank where encoder and decoder should be split.')
    group.add_argument('--model-parallel-size', type=int, default=None,
                       help='Old model parallel argument, do not use. Use '
                       '--tensor-model-parallel-size instead.')
......
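For context, here is a minimal illustration (not part of the commit) of what the new `--pipeline-model-parallel-split-rank` flag means. The flag name is from the diff above; the sizes are hypothetical: with 4 pipeline stages and a split rank of 2, the first two pipeline ranks run encoder stages and the remaining two run decoder stages.

```python
# Illustration only; the flag is real, the values below are hypothetical.
pipeline_model_parallel_size = 4
pipeline_model_parallel_split_rank = 2

# Mirrors the assertion added in arguments.py: the split rank must be a valid stage.
assert pipeline_model_parallel_split_rank < pipeline_model_parallel_size

for rank in range(pipeline_model_parallel_size):
    side = 'encoder' if rank < pipeline_model_parallel_split_rank else 'decoder'
    print(f'pipeline rank {rank}: {side} stage')
```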
@@ -193,7 +193,8 @@ def _initialize_distributed():
        else:
            mpu.initialize_model_parallel(args.tensor_model_parallel_size,
                                          args.pipeline_model_parallel_size,
                                          args.virtual_pipeline_model_parallel_size,
                                          args.pipeline_model_parallel_split_rank)

def _init_autoresume():
......
@@ -21,3 +21,4 @@ from .gpt_model import GPTModel
from .t5_model import T5Model
from .language_model import get_language_model
from .module import Float16Module
from .enums import ModelType
@@ -15,6 +15,10 @@
import enum
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
class LayerType(enum.Enum):
    encoder = 1
    decoder = 2
......
@@ -45,7 +45,8 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,

def get_language_model(num_tokentypes, add_pooler,
                       encoder_attn_mask_type, init_method=None,
                       scaled_init_method=None, add_encoder=True,
                       add_decoder=False,
                       decoder_attn_mask_type=AttnMaskType.causal,
                       pre_process=True, post_process=True):
    """Build language model and return along with the key to save."""
@@ -64,6 +65,7 @@ def get_language_model(num_tokentypes, add_pooler,
        scaled_init_method,
        encoder_attn_mask_type,
        num_tokentypes=num_tokentypes,
        add_encoder=add_encoder,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
@@ -159,6 +161,13 @@ class Embedding(MegatronModule):
        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
def zero_parameters(self):
"""Zero out all parameters in embedding."""
self.word_embeddings.weight.data.fill_(0)
self.position_embeddings.weight.data.fill_(0)
if self.num_tokentypes > 0:
self.tokentype_embeddings.weight.data.fill_(0)
    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
        token-type embeddings in case the pretrained model does not have it.
@@ -273,6 +282,7 @@ class TransformerLanguageModel(MegatronModule):
                 output_layer_init_method,
                 encoder_attn_mask_type,
                 num_tokentypes=0,
add_encoder=True,
                 add_decoder=False,
                 decoder_attn_mask_type=AttnMaskType.causal,
                 add_pooler=False,
@@ -286,10 +296,12 @@ class TransformerLanguageModel(MegatronModule):
        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
self.add_encoder = add_encoder
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler
self.encoder_hidden_state = None
        # Embeddings.
        if self.pre_process:
@@ -302,25 +314,33 @@ class TransformerLanguageModel(MegatronModule):
            self._embedding_key = 'embedding'

        # Transformer.
        # Encoder (usually set to True, False if part of an encoder-decoder
        # architecture and in encoder-only stage).
        if self.add_encoder:
            self.encoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                self_attn_mask_type=self.encoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process
            )
            self._encoder_key = 'encoder'
        else:
            self.encoder = None

        # Decoder (usually set to False, True if part of an encoder-decoder
        # architecture and in decoder-only stage).
        if self.add_decoder:
-            assert args.pipeline_model_parallel_size == 1, \
-                'pipeline parallelism is not supported in the presence of decoder'
            self.decoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process)
            self._decoder_key = 'decoder'
        else:
            self.decoder = None
        if self.post_process:
            # Pooler.
@@ -330,7 +350,25 @@ class TransformerLanguageModel(MegatronModule):
    def set_input_tensor(self, input_tensor):
        """ See megatron.model.transformer.set_input_tensor()"""
-        self.encoder.set_input_tensor(input_tensor)
        if self.add_encoder and self.add_decoder:
assert len(input_tensor) == 1, \
'input_tensor should only be length 1 for stage with both encoder and decoder'
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_encoder:
assert len(input_tensor) == 1, \
'input_tensor should only be length 1 for stage with only encoder'
self.encoder.set_input_tensor(input_tensor[0])
elif self.add_decoder:
if len(input_tensor) == 2:
self.decoder.set_input_tensor(input_tensor[0])
self.encoder_hidden_state = input_tensor[1]
elif len(input_tensor) == 1:
self.decoder.set_input_tensor(None)
self.encoder_hidden_state = input_tensor[0]
else:
raise Exception('input_tensor must have either length 1 or 2')
else:
raise Exception('Stage must have at least either encoder or decoder')
    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
@@ -338,20 +376,22 @@ class TransformerLanguageModel(MegatronModule):
                get_key_value=False, pooling_sequence_index=0,
                enc_hidden_states=None, output_enc_hidden=False):
        # Encoder embedding.
        if self.pre_process:
            encoder_input = self.embedding(enc_input_ids, enc_position_ids,
                                           tokentype_ids=tokentype_ids)
        else:
            encoder_input = None

        # Run encoder.
        if enc_hidden_states is None:
            if self.encoder is not None:
                encoder_output = self.encoder(encoder_input,
                                              enc_attn_mask,
                                              layer_past=layer_past,
                                              get_key_value=get_key_value)
            else:
                encoder_output = self.encoder_hidden_state
        else:
            encoder_output = enc_hidden_states.to(encoder_input.dtype)
@@ -369,11 +409,15 @@ class TransformerLanguageModel(MegatronModule):
            else:
                return encoder_output
        # Decoder embedding.
        if self.pre_process:
            decoder_input = self.embedding(dec_input_ids,
                                           dec_position_ids)
        else:
            decoder_input = None

        # Run decoder.
        decoder_output = self.decoder(decoder_input,
                                      dec_attn_mask,
                                      layer_past=layer_past,
                                      get_key_value=get_key_value,
@@ -394,9 +438,10 @@ class TransformerLanguageModel(MegatronModule):
            state_dict_[self._embedding_key] \
                = self.embedding.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        if self.add_encoder:
            state_dict_[self._encoder_key] \
                = self.encoder.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        if self.post_process:
            if self.add_pooler:
                state_dict_[self._pooler_key] \
@@ -425,38 +470,39 @@ class TransformerLanguageModel(MegatronModule):
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Encoder.
        if self.add_encoder:
            if self._encoder_key in state_dict:
                state_dict_ = state_dict[self._encoder_key]
            # For backward compatibility.
            elif 'transformer' in state_dict:
                state_dict_ = state_dict['transformer']
            else:
                # For backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if 'transformer.' in key:
                        state_dict_[key.split('transformer.')[1]] = state_dict[key]

            # For backward compatibility.
            state_dict_self_attention = {}
            for key in state_dict_.keys():
                if '.attention.' in key:
                    state_dict_self_attention[key.replace(".attention.",
                        ".self_attention.")] = state_dict_[key]
                else:
                    state_dict_self_attention[key] = state_dict_[key]
            state_dict_ = state_dict_self_attention

            self.encoder.load_state_dict(state_dict_, strict=strict)

        # Pooler.
        if self.post_process:
            if self.add_pooler:
                assert 'pooler' in state_dict, \
                    'could not find data for pooler in the checkpoint'
                self.pooler.load_state_dict(state_dict[self._pooler_key],
                                            strict=strict)
        # Decoder.
        if self.add_decoder:
            assert 'decoder' in state_dict, \
                'could not find data for pooler in the checkpoint'
......
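As a hedged, standalone sketch (not a Megatron API) of the input-routing contract that the new `set_input_tensor` above implements: an encoder stage always receives one tensor, while a decoder-only stage may receive either one tensor (the encoder output, at the first decoder stage) or two (the incoming decoder activation plus the forwarded encoder output).

```python
# Illustration only; names below are invented for this sketch.
def route_inputs(add_encoder, add_decoder, input_tensors):
    if add_encoder:
        # Encoder stages (with or without a decoder) take a single tensor.
        assert len(input_tensors) == 1
        return {'encoder_input': input_tensors[0]}
    assert add_decoder, 'a stage must hold at least an encoder or a decoder'
    if len(input_tensors) == 1:
        # First decoder stage: only the encoder output arrives.
        return {'decoder_input': None, 'encoder_hidden_state': input_tensors[0]}
    decoder_input, encoder_hidden_state = input_tensors
    return {'decoder_input': decoder_input,
            'encoder_hidden_state': encoder_hidden_state}
```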
@@ -51,15 +51,14 @@ class MegatronModule(torch.nn.Module):

    def word_embeddings_weight(self):
-        if mpu.is_pipeline_first_stage(ignore_virtual=True):
        if not mpu.is_pipeline_last_stage(ignore_virtual=True) or \
                mpu.get_pipeline_model_parallel_world_size() == 1:
            return self.language_model.embedding.word_embeddings.weight
-        if mpu.is_pipeline_last_stage(ignore_virtual=True):
        else:
            if not self.share_word_embeddings:
                raise Exception('word_embeddings_weight() called for last '
                                'stage, but share_word_embeddings is false')
            return self.word_embeddings.weight
-        raise Exception('word_embeddings_weight() should be '
-                        'called for first and last stage only')

    def initialize_word_embeddings(self, init_method_normal):
@@ -69,12 +68,12 @@ class MegatronModule(torch.nn.Module):
                            'share_word_embeddings is false')

        # This function just initializes the word embeddings in the final stage
        # when we are using pipeline parallelism. Nothing to do if we aren't
        # using pipeline parallelism.
        if args.pipeline_model_parallel_size == 1:
            return

        # Parameters are shared between the word embeddings layers, and the
        # heads at the end of the model. In a pipelined setup with more than
        # one stage, the initial embedding layer and the head are on different
        # workers, so we do the following:
@@ -97,12 +96,34 @@ class MegatronModule(torch.nn.Module):
            self.word_embeddings.weight.data.fill_(0)
            self.word_embeddings.weight.shared = True
# Zero out initial weights for decoder embedding.
# NOTE: We don't currently support T5 with the interleaved schedule.
if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \
not mpu.is_pipeline_last_stage(ignore_virtual=True) and \
mpu.is_rank_in_embedding_group():
self.language_model.embedding.zero_parameters()
        # Ensure that first and last stages have the same initial parameter
        # values.
        if torch.distributed.is_initialized():
-            if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
            if mpu.is_rank_in_embedding_group():
                torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                             group=mpu.get_embedding_group())
# All-reduce other embeddings as well as necessary. The last stage
# does not have these other embeddings, so just create placeholder
# tensors of the right shape with all zeros.
# NOTE: We don't currently support T5 with the interleaved schedule.
if args.pipeline_model_parallel_split_rank is not None:
# TODO: Support tokentype embedding.
dimensions = (args.max_position_embeddings, args.hidden_size)
if mpu.is_pipeline_last_stage(ignore_virtual=True):
position_embeddings = torch.nn.Embedding(*dimensions).cuda()
position_embeddings.weight.data.fill_(0)
else:
self.language_model.embedding.cuda()
position_embeddings = self.language_model.embedding.position_embeddings
torch.distributed.all_reduce(position_embeddings.weight.data,
group=mpu.get_embedding_group())
        else:
            print("WARNING! Distributed processes aren't initialized, so "
                  "word embeddings in the last layer are not initialized. "
......
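The zero-fill followed by an all-reduce above is what keeps the duplicated embedding copies identical across the embedding group. A toy, single-process illustration (not Megatron code) of why the duplicated copies are zero-filled first: summing one normally initialized copy with zero-filled copies leaves every rank holding the first stage's values, which is exactly what a SUM all-reduce over the embedding group computes.

```python
import torch

torch.manual_seed(0)
first_stage_copy = torch.randn(4, 8)   # initialized normally on the first stage
split_stage_copy = torch.zeros(4, 8)   # zero-filled on the encoder/decoder split rank
last_stage_copy = torch.zeros(4, 8)    # zero-filled on the last stage

# torch.distributed.all_reduce with the default SUM op would compute this sum
# on every member of the embedding group.
reduced = first_stage_copy + split_stage_copy + last_stage_copy
assert torch.equal(reduced, first_stage_copy)
```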
@@ -86,7 +86,13 @@ class T5LMHead(MegatronModule):

class T5Model(MegatronModule):
    """T5 Language model."""

    def __init__(self,
                 num_tokentypes=0,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True,
                 add_encoder=True,
                 add_decoder=True):
        super(T5Model, self).__init__()
        args = get_args()
@@ -95,19 +101,29 @@ class T5Model(MegatronModule):
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)
self.pre_process = pre_process
self.post_process = post_process
self.add_encoder = add_encoder
self.add_decoder = add_decoder
        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            add_encoder=add_encoder,
            add_decoder=add_decoder,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process)

        self.initialize_word_embeddings(init_method_normal)

        if self.post_process and self.add_decoder:
            self.lm_head = T5LMHead(
                self.word_embeddings_weight().size(0),
                parallel_output)
            self._lm_head_key = 'lm_head'
    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
@@ -134,22 +150,28 @@ class T5Model(MegatronModule):
            tokentype_ids=tokentype_ids,
            enc_hidden_states=enc_hidden_states)
        if self.post_process and self.add_decoder:
            decoder_output, encoder_output = lm_output
            # Output.
            lm_logits = self.lm_head(decoder_output,
                                     self.word_embeddings_weight())

            if lm_labels is None:
                return lm_logits
            else:
                if self.fp16_lm_cross_entropy:
                    assert lm_logits.dtype == torch.half
                    lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
                else:
                    lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
                                                               lm_labels)
                return lm_loss
        elif self.add_decoder and not self.add_encoder:
            decoder_output, encoder_output = lm_output
            return decoder_output
        else:
            encoder_output = lm_output
            return encoder_output
    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
@@ -160,9 +182,14 @@ class T5Model(MegatronModule):
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if self.post_process and self.add_decoder:
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
# Save word_embeddings.
if self.post_process and not self.pre_process and self.add_decoder:
state_dict_[self._word_embeddings_for_head_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
@@ -170,5 +197,10 @@ class T5Model(MegatronModule):
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self.post_process and self.add_decoder:
            self.lm_head.load_state_dict(state_dict[self._lm_head_key],
                                         strict=strict)
# Load word embeddings.
if self.post_process and not self.pre_process and self.add_decoder:
self.word_embeddings.load_state_dict(
state_dict[self._word_embeddings_for_head_key], strict=strict)
@@ -21,7 +21,7 @@ import torch.nn.functional as F
from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
@@ -548,9 +548,8 @@ class ParallelTransformer(MegatronModule):
        self.checkpoint_num_layers = args.checkpoint_num_layers

        # Number of layers.
-        assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
-            'num_layers must be divisible by pipeline_model_parallel_size'
-        self.num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()
        self.num_layers = mpu.get_num_layers(
            args, args.model_type == ModelType.encoder_and_decoder)

        # Transformer layers.
        def build_layer(layer_number):
......
@@ -31,6 +31,10 @@ from .initialize import get_pipeline_model_parallel_group
from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
from .initialize import is_rank_in_embedding_group
from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_split
from .initialize import is_pipeline_stage_at_split
from .initialize import get_num_layers
from .initialize import get_tensor_model_parallel_src_rank
from .initialize import get_pipeline_model_parallel_first_rank
from .initialize import get_pipeline_model_parallel_last_rank
......
@@ -34,6 +34,7 @@ _DATA_PARALLEL_GROUP = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
@@ -41,8 +42,11 @@ _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None
# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None

def is_unitialized():
@@ -52,13 +56,19 @@ def is_unitialized():

def initialize_model_parallel(tensor_model_parallel_size_=1,
                              pipeline_model_parallel_size_=1,
                              virtual_pipeline_model_parallel_size_=None,
                              pipeline_model_parallel_split_rank_=None):
    """
    Initialize model data parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
pipeline).
pipeline_model_parallel_split_rank: for models with both encoder and decoder,
rank in pipeline with split point.
    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
@@ -101,6 +111,10 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
if pipeline_model_parallel_split_rank_ is not None:
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank_
    rank = torch.distributed.get_rank()

    # Build the data-parallel groups.
@@ -148,6 +162,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
        'pipeline model parallel group is already initialized'
    global _EMBEDDING_GROUP
global _EMBEDDING_GLOBAL_RANKS
    assert _EMBEDDING_GROUP is None, \
        'embedding group is already initialized'
    for i in range(num_pipeline_model_parallel_groups):
@@ -161,11 +176,18 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
        # first and last stages).
        if len(ranks) > 1:
            embedding_ranks = [ranks[0], ranks[-1]]
if pipeline_model_parallel_split_rank_ is not None and \
pipeline_model_parallel_split_rank_ not in embedding_ranks:
embedding_ranks = [ranks[0],
ranks[pipeline_model_parallel_split_rank_],
ranks[-1]]
        else:
            embedding_ranks = ranks
        group = torch.distributed.new_group(embedding_ranks)
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
if rank in ranks:
_EMBEDDING_GLOBAL_RANKS = embedding_ranks
def model_parallel_is_initialized():
@@ -268,6 +290,30 @@ def get_pipeline_model_parallel_rank():
    return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
def get_num_layers(args, is_encoder_and_decoder_model):
"""Compute the number of transformer layers resident on the current rank."""
if get_pipeline_model_parallel_world_size() > 1:
if is_encoder_and_decoder_model:
assert args.pipeline_model_parallel_split_rank is not None
num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
num_ranks_in_decoder = get_pipeline_model_parallel_world_size() - num_ranks_in_encoder
assert args.num_layers % num_ranks_in_encoder == 0, \
'num_layers must be divisible by number of ranks given to encoder'
assert args.num_layers % num_ranks_in_decoder == 0, \
'num_layers must be divisible by number of ranks given to decoder'
if is_pipeline_stage_before_split():
num_layers = args.num_layers // num_ranks_in_encoder
else:
num_layers = args.num_layers // num_ranks_in_decoder
else:
assert args.num_layers % get_pipeline_model_parallel_world_size() == 0, \
'num_layers must be divisible by pipeline_model_parallel_size'
num_layers = args.num_layers // get_pipeline_model_parallel_world_size()
else:
num_layers = args.num_layers
return num_layers
def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if in the first pipeline model-parallel stage, False otherwise."""
    if not ignore_virtual:
@@ -290,6 +336,61 @@ def is_pipeline_last_stage(ignore_virtual=False):
            get_pipeline_model_parallel_world_size() - 1)
def is_rank_in_embedding_group(ignore_virtual=False):
"""Return true if current rank is in embedding group, False otherwise."""
rank = torch.distributed.get_rank()
global _EMBEDDING_GLOBAL_RANKS
if ignore_virtual:
return rank in _EMBEDDING_GLOBAL_RANKS
if rank in _EMBEDDING_GLOBAL_RANKS:
if rank == _EMBEDDING_GLOBAL_RANKS[0]:
return is_pipeline_first_stage(ignore_virtual=False)
elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
return is_pipeline_last_stage(ignore_virtual=False)
else:
return True
return False
def is_pipeline_stage_before_split(rank=None):
"""Return True if pipeline stage executes encoder block for a model
with both encoder and decoder."""
if get_pipeline_model_parallel_world_size() == 1:
return True
if rank is None:
rank = get_pipeline_model_parallel_rank()
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
return True
if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
return True
return False
def is_pipeline_stage_after_split(rank=None):
"""Return True if pipeline stage executes decoder block for a model
with both encoder and decoder."""
if get_pipeline_model_parallel_world_size() == 1:
return True
if rank is None:
rank = get_pipeline_model_parallel_rank()
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
return True
if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
return True
return False
def is_pipeline_stage_at_split():
"""Return true if pipeline stage executes decoder block and next
stage executes encoder block for a model with both encoder and
decoder."""
rank = get_pipeline_model_parallel_rank()
return is_pipeline_stage_before_split(rank) and \
is_pipeline_stage_after_split(rank+1)
def get_virtual_pipeline_model_parallel_rank():
    """Return the virtual pipeline-parallel rank."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
......
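The layer-count arithmetic in the new `get_num_layers` above can be checked with a standalone re-implementation. This is illustration only, with hypothetical sizes (12 encoder and 12 decoder layers, 4 pipeline stages, split rank 1): the single encoder stage holds all 12 encoder layers and each of the three decoder stages holds 4 decoder layers.

```python
def layers_on_rank(num_layers, pipeline_size, split_rank, rank):
    # Mirrors mpu.get_num_layers for the encoder_and_decoder case.
    num_ranks_in_encoder = split_rank
    num_ranks_in_decoder = pipeline_size - split_rank
    assert num_layers % num_ranks_in_encoder == 0
    assert num_layers % num_ranks_in_decoder == 0
    if rank < split_rank:                       # encoder stage
        return num_layers // num_ranks_in_encoder
    return num_layers // num_ranks_in_decoder   # decoder stage

print([layers_on_rank(12, 4, 1, r) for r in range(4)])   # [12, 4, 4, 4]
```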
@@ -58,7 +58,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
        grad_not_none = param.grad is not None
        is_not_shared = param_is_not_shared(param)
        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
        if grad_not_none:
            grad = param.grad.detach()
        if grad_not_none:
            # Make sure the grads are in fp32
            assert param.grad.type() == 'torch.cuda.FloatTensor'
......
@@ -173,7 +173,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
            a `main_grad` field. If this is set, we are assuming
            that the model parameters are store in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a continuous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
@@ -305,7 +305,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                if self.params_have_main_grad and hasattr(model_param, 'main_grad'):
                    main_param.grad = model_param.main_grad.float()
                else:
                    if model_param.grad is not None:
......
@@ -22,8 +22,8 @@ from megatron import mpu

def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                 tensor_shape,
                 use_ring_exchange=False,
                 dtype_=None):
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.
@@ -37,16 +37,13 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                   previous rank.
        recv_next: boolean for whether tensor should be received from
                   next rank.
        tensor_shape: shape of tensor to receive (this method assumes that all
                      tensors sent and received in a single function call are
                      the same shape).
        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
                           API should be used.
        dtype_: optional, this is used when the tensor that needs to be
                communicated is different from args.params_dtype.
    Returns:
        (tensor_recv_prev, tensor_recv_next)
    """
@@ -56,12 +53,15 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    override_scatter_gather_tensors_in_pipeline = False
    if args.scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
        if tensor_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
            tensor_chunk_shape = tensor_chunk_shape // \
                mpu.get_tensor_model_parallel_world_size()
        else:
            tensor_chunk_shape = tensor_shape
            override_scatter_gather_tensors_in_pipeline = True
    else:
        tensor_chunk_shape = tensor_shape
    dtype = args.params_dtype
@@ -143,9 +143,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
    return tensor_recv_prev, tensor_recv_next


def recv_forward(tensor_shape, dtype_=None, timers=None):
    """Receive tensor from previous rank in pipeline (forward receive)."""
    if mpu.is_pipeline_first_stage():
@@ -159,15 +157,13 @@ def recv_forward(tensor_shape=None,
            recv_prev=True,
            recv_next=False,
            tensor_shape=tensor_shape,
            dtype_=dtype_)
        if timers is not None:
            timers('forward-recv').stop()
    return input_tensor


def recv_backward(tensor_shape, timers=None):
    """Receive tensor from next rank in pipeline (backward receive)."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
@@ -178,15 +174,14 @@ def recv_backward(timers=None):
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True,
            tensor_shape=tensor_shape)
        if timers is not None:
            timers('backward-recv').stop()
    return output_tensor_grad


def send_forward(output_tensor, tensor_shape, dtype_=None, timers=None):
    """Send tensor to next rank in pipeline (forward send)."""
    if not mpu.is_pipeline_last_stage():
@@ -197,14 +192,13 @@ def send_forward(output_tensor, timers=None,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=False,
            tensor_shape=tensor_shape,
            dtype_=dtype_)
        if timers is not None:
            timers('forward-send').stop()


def send_backward(input_tensor_grad, tensor_shape, timers=None):
    """Send tensor to previous rank in pipeline (backward send)."""
    if not mpu.is_pipeline_first_stage():
        if timers is not None:
@@ -213,12 +207,13 @@ def send_backward(input_tensor_grad, timers=None):
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=False,
            recv_next=False,
            tensor_shape=tensor_shape)
        if timers is not None:
            timers('backward-send').stop()


def send_forward_recv_backward(output_tensor, tensor_shape, timers=None):
    """Batched send and recv with next rank in pipeline."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
@@ -229,13 +224,14 @@ def send_forward_recv_backward(output_tensor, timers=None):
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True,
            tensor_shape=tensor_shape)
        if timers is not None:
            timers('forward-send-backward-recv').stop()
    return output_tensor_grad


def send_backward_recv_forward(input_tensor_grad, tensor_shape, timers=None):
    """Batched send and recv with previous rank in pipeline."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
@@ -246,13 +242,14 @@ def send_backward_recv_forward(input_tensor_grad, timers=None):
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=True,
            recv_next=False,
            tensor_shape=tensor_shape)
        if timers is not None:
            timers('backward-send-forward-recv').stop()
    return input_tensor


def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape, timers=None):
    """Batched recv from previous rank and send to next rank in pipeline."""
    if timers is not None:
        timers('forward-send-forward-recv').start()
@@ -260,13 +257,14 @@ def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('forward-send-forward-recv').stop()
    return input_tensor


def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape, timers=None):
    """Batched recv from next rank and send to previous rank in pipeline."""
    if timers is not None:
        timers('backward-send-backward-recv').start()
@@ -274,7 +272,8 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=recv_next,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('backward-send-backward-recv').stop()
    return output_tensor_grad
@@ -282,7 +281,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):

def send_forward_backward_recv_forward_backward(
        output_tensor, input_tensor_grad, recv_prev,
        recv_next, tensor_shape, timers=None):
    """Batched send and recv with previous and next ranks in pipeline."""
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').start()
@@ -290,7 +289,8 @@ def send_forward_backward_recv_forward_backward(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
        recv_next=recv_next,
        tensor_shape=tensor_shape)
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').stop()
    return input_tensor, output_tensor_grad
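Since callers must now pass `tensor_shape` explicitly, the scatter-gather path in `_communicate` checks whether the flattened tensor divides evenly across tensor-parallel ranks and falls back to sending the whole tensor otherwise. A toy re-implementation of that check, with hypothetical sizes (not Megatron code):

```python
from functools import reduce
import operator

def chunk_shape(tensor_shape, tensor_model_parallel_world_size):
    numel = reduce(operator.mul, tensor_shape, 1)
    if numel % tensor_model_parallel_world_size == 0:
        # Scatter-gather path: each rank communicates 1/TP of the elements.
        return numel // tensor_model_parallel_world_size
    # Fall back to communicating the full, unsplit tensor.
    return tensor_shape

print(chunk_shape((512, 4, 1024), 8))   # 262144: the split is possible
print(chunk_shape((7, 3, 5), 4))        # (7, 3, 5): falls back to the full shape
```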
@@ -25,6 +25,8 @@ from megatron import p2p_communication
from megatron.utils import unwrap_model
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType


def get_forward_backward_func():
    args = get_args()
@@ -48,11 +50,18 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
    passed-in input_tensor is used.

    Returns output tensor."""
    args = get_args()
    timers = get_timers()

    timers('forward-compute').start()
    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))
unwrap_output_tensor = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_output_tensor = True
    unwrapped_model.set_input_tensor(input_tensor)
    output_tensor, loss_func = forward_step_func(data_iterator, model)
    if mpu.is_pipeline_last_stage():
@@ -62,7 +71,12 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
        losses_reduced.append(loss_reduced)
    timers('forward-compute').stop()

-    return output_tensor
    if mpu.is_pipeline_stage_after_split() and \
            args.model_type == ModelType.encoder_and_decoder:
        return [output_tensor, input_tensor[-1]]
    if unwrap_output_tensor:
        return output_tensor
    return [output_tensor]

def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
@@ -73,24 +87,53 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
    Returns gradient of loss with respect to input tensor (None if first
    stage)."""
# NOTE: This code currently can handle at most one skip connection. It
# needs to be modified slightly to support arbitrary numbers of skip
# connections.
    args = get_args()
    timers = get_timers()
    timers('backward-compute').start()

    # Retain the grad on the input_tensor.
    unwrap_input_tensor_grad = False
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
        unwrap_input_tensor_grad = True
    for x in input_tensor:
        if x is not None:
            x.retain_grad()

    if not isinstance(output_tensor, list):
        output_tensor = [output_tensor]
    if not isinstance(output_tensor_grad, list):
        output_tensor_grad = [output_tensor_grad]

    # Backward pass.
    if output_tensor_grad[0] is None:
        output_tensor = optimizer.scale_loss(output_tensor[0])
    torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])

    # Collect the grad of the input_tensor.
    input_tensor_grad = [None]
    if input_tensor is not None:
        input_tensor_grad = []
        for x in input_tensor:
            if x is None:
                input_tensor_grad.append(None)
            else:
                input_tensor_grad.append(x.grad)
# Handle single skip connection if it exists (encoder_hidden_state in
# model with encoder and decoder).
if mpu.get_pipeline_model_parallel_world_size() > 1 and \
mpu.is_pipeline_stage_after_split() and \
args.model_type == ModelType.encoder_and_decoder:
if output_tensor_grad[1] is not None:
input_tensor_grad[-1].add_(output_tensor_grad[1])
if unwrap_input_tensor_grad:
input_tensor_grad = input_tensor_grad[0]
    timers('backward-compute').stop()
@@ -153,6 +196,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
    pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()
args = get_args()
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    # Compute number of warmup and remaining microbatches.
    num_model_chunks = len(model)
    num_microbatches = get_num_microbatches() * num_model_chunks
@@ -231,7 +277,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
    # Run warmup forward passes.
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(tensor_shape, timers=timers))
    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)
...@@ -260,12 +306,15 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat ...@@ -260,12 +306,15 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
p2p_communication.send_forward_backward_recv_forward_backward( p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad, output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next, recv_prev=recv_prev, recv_next=recv_next,
tensor_shape=tensor_shape,
timers=timers)
output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
else:
input_tensor = \
p2p_communication.send_forward_recv_forward(
output_tensor, recv_prev=recv_prev,
tensor_shape=tensor_shape,
timers=timers)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
# Run 1F1B in steady state.
...@@ -329,7 +378,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next,
tensor_shape=tensor_shape, timers=timers)
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
...@@ -343,7 +392,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks-1].append(
p2p_communication.recv_backward(tensor_shape, timers=timers))
for k in range(num_microbatches_remaining, num_microbatches):
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
...@@ -355,11 +404,107 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append(
p2p_communication.send_backward_recv_backward(
input_tensor_grad, recv_next=recv_next,
tensor_shape=tensor_shape,
timers=timers))
return losses_reduced
def get_tensor_shapes(rank, model_type):
# Determine right tensor sizes (based on position of rank with respect to split
# rank) and model size.
# Send two tensors if model is T5 and rank is in decoder stage:
# first tensor is decoder (pre-transpose),
# second tensor is encoder (post-transpose).
# If model is T5 and rank is at the boundary:
# send one tensor (post-transpose from encoder).
# Otherwise, send one tensor (pre-transpose).
args = get_args()
tensor_shapes = []
if model_type == ModelType.encoder_and_decoder:
if mpu.is_pipeline_stage_before_split(rank):
# If next rank is after split, then need transpose for encoder_hidden_state.
if mpu.is_pipeline_stage_before_split(rank+1):
tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
else:
tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
else:
tensor_shapes.append((args.decoder_seq_length, args.micro_batch_size, args.hidden_size))
tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
else:
tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
return tensor_shapes
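As a quick sanity check of the branches above, here is a worked example with made-up sizes (seq_length=512, decoder_seq_length=128, micro_batch_size=4, hidden_size=1024, pipeline world size 4, --pipeline-model-parallel-split-rank 2, so ranks 0-1 hold the encoder and ranks 2-3 the decoder). The snippet mirrors get_tensor_shapes() for ModelType.encoder_and_decoder rather than calling it, since the real function depends on global args and mpu state:

seq_length, decoder_seq_length = 512, 128
micro_batch_size, hidden_size = 4, 1024
split_rank, world_size = 2, 4

def expected_shapes(rank):
    # Mirrors the ModelType.encoder_and_decoder branches above.
    if rank < split_rank:                  # encoder stage
        if rank + 1 < split_rank:          # next stage is still encoder
            return [(seq_length, micro_batch_size, hidden_size)]
        # boundary: encoder output is sent post-transpose
        return [(micro_batch_size, seq_length, hidden_size)]
    # decoder stage: decoder activations plus forwarded encoder hidden state
    return [(decoder_seq_length, micro_batch_size, hidden_size),
            (micro_batch_size, seq_length, hidden_size)]

for rank in range(world_size):
    print(rank, expected_shapes(rank))
# 0 [(512, 4, 1024)]
# 1 [(4, 512, 1024)]
# 2 [(128, 4, 1024), (4, 512, 1024)]
# 3 [(128, 4, 1024), (4, 512, 1024)]

Rank r then sends tensors with expected_shapes(r) and receives tensors with expected_shapes(r-1), which is exactly how recv_tensor_shapes and send_tensor_shapes are built from rank-1 and rank further down.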
def recv_forward(tensor_shapes, timers):
input_tensors = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
input_tensors.append(None)
else:
input_tensors.append(p2p_communication.recv_forward(tensor_shape,
timers=timers))
return input_tensors
def recv_backward(tensor_shapes, timers):
output_tensor_grads = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
output_tensor_grads.append(None)
else:
output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape,
timers=timers))
return output_tensor_grads
def send_forward(output_tensors, tensor_shapes, timers):
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_forward(output_tensor, tensor_shape, timers=timers)
def send_backward(input_tensor_grads, tensor_shapes, timers):
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_backward(input_tensor_grad, tensor_shape, timers=timers)
def send_forward_recv_backward(output_tensors, tensor_shapes, timers):
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
output_tensor_grads = []
for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
if tensor_shape is None:
output_tensor_grads.append(None)
continue
output_tensor_grad = p2p_communication.send_forward_recv_backward(
output_tensor, tensor_shape, timers=timers)
output_tensor_grads.append(output_tensor_grad)
return output_tensor_grads
def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
input_tensors = []
for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
if tensor_shape is None:
input_tensors.append(None)
continue
input_tensor = p2p_communication.send_backward_recv_forward(
input_tensor_grad, tensor_shape, timers=timers)
input_tensors.append(input_tensor)
return input_tensors
def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
model, optimizer, timers,
forward_only):
...@@ -383,16 +528,23 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
model_type = unwrapped_model.model_type
rank = mpu.get_pipeline_model_parallel_rank()
recv_tensor_shapes = get_tensor_shapes(rank-1, model_type)
send_tensor_shapes = get_tensor_shapes(rank, model_type)
input_tensors = []
output_tensors = []
losses_reduced = []
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
send_forward(output_tensor, send_tensor_shapes, timers=timers)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
...@@ -401,7 +553,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
...@@ -410,12 +562,12 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if forward_only:
send_forward(output_tensor, send_tensor_shapes, timers=timers)
else:
output_tensor_grad = \
send_forward_recv_backward(output_tensor,
send_tensor_shapes,
timers=timers)
# Add input_tensor and output_tensor to end of list, then pop from the
# start of the list for backward pass.
input_tensors.append(input_tensor)
...@@ -423,7 +575,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
if forward_only:
if not last_iteration:
input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
else:
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
...@@ -433,11 +585,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
if last_iteration:
input_tensor = None
send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
else:
input_tensor = \
send_backward_recv_forward(
input_tensor_grad, recv_tensor_shapes, timers=timers)
# Run cooldown backward passes.
if not forward_only:
...@@ -445,12 +597,12 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
output_tensor_grad = recv_backward(send_tensor_shapes, timers=timers)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
return losses_reduced
...@@ -38,6 +38,7 @@ from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.model import Float16Module
from megatron.model import ModelType
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
...@@ -61,6 +62,7 @@ def print_datetime(string):
def pretrain(train_valid_test_dataset_provider,
model_provider,
model_type,
forward_step_func,
extra_args_provider=None,
args_defaults={}):
...@@ -77,6 +79,7 @@ def pretrain(train_valid_test_dataset_provider,
train/valid/test dataset and returns `train, valid, test` datasets.
model_provider: a function that returns a vanilla version of the
model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
model_type: an enum that specifies the type of model being trained.
forward_step_func: a function that takes a `data iterator` and `model`,
and returns a `loss` scalar with a dictionary with key:values being
the info we would like to monitor during training, for example
...@@ -109,7 +112,8 @@ def pretrain(train_valid_test_dataset_provider,
# Model, optimizer, and learning rate.
timers('model-and-optimizer-setup').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
model_type)
timers('model-and-optimizer-setup').stop()
print_datetime('after model, optimizer, and learning rate '
'scheduler are built')
...@@ -189,13 +193,16 @@ def update_train_iters(args):
print_rank_0('setting training iterations to {}'.format(args.train_iters))
def get_model(model_provider_func, model_type):
"""Build the model."""
args = get_args()
args.model_type = model_type
# Build model.
if mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.virtual_pipeline_model_parallel_size is not None:
assert model_type != ModelType.encoder_and_decoder, \
"Interleaved schedule not supported for model with both encoder and decoder"
model = []
for i in range(args.virtual_pipeline_model_parallel_size):
mpu.set_virtual_pipeline_model_parallel_rank(i)
...@@ -206,14 +213,36 @@ def get_model(model_provider_func):
pre_process=pre_process,
post_process=post_process
)
this_model.model_type = model_type
model.append(this_model)
else:
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
add_encoder = True
add_decoder = True
if model_type == ModelType.encoder_and_decoder:
if mpu.get_pipeline_model_parallel_world_size() > 1:
assert args.pipeline_model_parallel_split_rank is not None, \
"Split rank needs to be specified for model with both encoder and decoder"
rank = mpu.get_pipeline_model_parallel_rank()
split_rank = args.pipeline_model_parallel_split_rank
world_size = mpu.get_pipeline_model_parallel_world_size()
pre_process = rank == 0 or rank == split_rank
post_process = (rank == (split_rank - 1)) or (
rank == (world_size - 1))
add_encoder = mpu.is_pipeline_stage_before_split()
add_decoder = mpu.is_pipeline_stage_after_split()
model = model_provider_func(
pre_process=pre_process,
post_process=post_process,
add_encoder=add_encoder,
add_decoder=add_decoder)
else:
model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
model.model_type = model_type
if not isinstance(model, list):
model = [model]
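For reference, here is how the pre_process / post_process / add_encoder / add_decoder flags above fall out per pipeline rank in a hypothetical configuration (pipeline world size 4, --pipeline-model-parallel-split-rank 2). The snippet re-derives the flags locally instead of calling mpu, so the numbers and the trailing interpretation are illustrative only:

world_size, split_rank = 4, 2

for rank in range(world_size):
    pre_process = rank == 0 or rank == split_rank
    post_process = (rank == (split_rank - 1)) or (rank == (world_size - 1))
    add_encoder = rank < split_rank    # mpu.is_pipeline_stage_before_split()
    add_decoder = rank >= split_rank   # mpu.is_pipeline_stage_after_split()
    print(rank, pre_process, post_process, add_encoder, add_decoder)
# 0 True  False True  False   -> encoder embeddings + first encoder layers
# 1 False True  True  False   -> last encoder layers + encoder output
# 2 True  False False True    -> decoder embeddings + first decoder layers
# 3 False True  False True    -> last decoder layers + output head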
...@@ -304,11 +333,11 @@ def get_learning_rate_scheduler(optimizer):
return lr_scheduler
def setup_model_and_optimizer(model_provider_func, model_type):
"""Setup model and optimizer."""
args = get_args()
model = get_model(model_provider_func, model_type)
unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, Float16Module))
...@@ -374,13 +403,14 @@ def train_step(forward_step_func, data_iterator,
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1]
else: # We do not support the interleaved schedule for T5 yet.
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
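The hunk is cut off here, but what typically follows the unwrapping is the word-embedding gradient all-reduce, which this change now runs over the embedding group (so, for a model with both encoder and decoder, the stage holding the decoder embedding can participate as well) instead of only the first/last pipeline stages. A rough sketch of that continuation, assuming the Megatron-LM helpers share_word_embeddings, word_embeddings_weight() and mpu.get_embedding_group() (none of which are shown in this diff):

if unwrapped_model.share_word_embeddings:
    word_embeddings_weight = unwrapped_model.word_embeddings_weight()
    # With the local DDP implementation, gradients are accumulated in
    # main_grad; otherwise they live in .grad.
    if args.DDP_impl == 'local':
        grad = word_embeddings_weight.main_grad
    else:
        grad = word_embeddings_weight.grad
    torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
timers('backward-embedding-all-reduce').stop()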
......
...@@ -25,7 +25,7 @@ from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel, ModelType
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
...@@ -143,5 +143,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider,
ModelType.encoder_or_decoder,
forward_step, args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
...@@ -23,7 +23,7 @@ from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.model import GPTModel, ModelType
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group
...@@ -121,5 +121,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider,
ModelType.encoder_or_decoder,
forward_step, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
...@@ -28,6 +28,7 @@ from megatron import get_timers
from megatron import mpu
from megatron.data.biencoder_dataset_utils import get_ict_batch
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import ModelType
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
...@@ -174,5 +175,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider,
pretrain_ict_model_provider,
ModelType.encoder_or_decoder,
forward_step,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
...@@ -26,18 +26,22 @@ from megatron import (
print_rank_0
)
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import T5Model, ModelType
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True,
add_encoder=True, add_decoder=True):
"""Build the model.""" """Build the model."""
assert pre_process and post_process, "T5 doesn't yet support pipelining"
print_rank_0('building T5 model ...') print_rank_0('building T5 model ...')
model = T5Model(num_tokentypes=0, model = T5Model(num_tokentypes=0,
parallel_output=True) parallel_output=True,
pre_process=pre_process,
post_process=post_process,
add_encoder=add_encoder,
add_decoder=add_decoder)
return model
...@@ -70,9 +74,7 @@ def get_batch(data_iterator):
def loss_func(loss_mask, output_tensor):
lm_loss_ = output_tensor.float()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
...@@ -130,5 +132,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, ModelType.encoder_and_decoder,
forward_step, args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
...@@ -20,6 +20,7 @@ import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model import ModelType
from megatron.model.vit_model import VitModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
...@@ -92,6 +93,7 @@ if __name__ == "__main__":
pretrain(
train_valid_test_datasets_provider,
model_provider,
ModelType.encoder_or_decoder,
forward_step,
args_defaults={'dataloader_type': 'cyclic'}
)