Commit 4b3519cb authored by Vijay Korthikanti

address review comments

parent 834d6dd5
@@ -316,19 +316,19 @@ class TransformerLanguageModelBase(MegatronModule):
             self_attn_mask_type=self_attn_mask_type)
         self._encoder_key = 'encoder'
 
-        # assuming pooler and decoder are in the last stage
-        # of the pipeline(to be revised)
-        if mpu.is_pipeline_last_stage():
-            # decoder
-            if self.add_decoder:
-                self.decoder = ParallelTransformer(
-                    attention_mask_func,
-                    self.init_method,
-                    output_layer_init_method,
-                    layer_type=LayerType.decoder,
-                    self_attn_mask_type=AttnMaskType.causal)
-                self._decoder_key = 'decoder'
+        # Decoder
+        if self.add_decoder:
+            assert args.pipeline_model_parallel_size == 1, \
+                'pipeline parallelism is not supported in the presence of decoder'
+            self.decoder = ParallelTransformer(
+                attention_mask_func,
+                self.init_method,
+                output_layer_init_method,
+                layer_type=LayerType.decoder,
+                self_attn_mask_type=AttnMaskType.causal)
+            self._decoder_key = 'decoder'
 
+        if mpu.is_pipeline_last_stage():
             # Pooler.
             if self.add_pooler:
                 self.pooler = Pooler(self.hidden_size, self.init_method)
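The behavioral change in this hunk is that the decoder is no longer built only on the last pipeline stage; instead, configuring a decoder at all now requires running without pipeline parallelism. A minimal standalone sketch of that guard (illustrative only, not Megatron code; the parameter names simply mirror Megatron's pipeline_model_parallel_size and add_decoder arguments):

def check_decoder_config(pipeline_model_parallel_size, add_decoder):
    # Mirrors the assert added above: a decoder is only supported when the
    # model is not split across pipeline stages.
    if add_decoder:
        assert pipeline_model_parallel_size == 1, \
            'pipeline parallelism is not supported in the presence of decoder'


check_decoder_config(pipeline_model_parallel_size=1, add_decoder=True)    # passes
# check_decoder_config(pipeline_model_parallel_size=2, add_decoder=True)  # AssertionError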
@@ -363,33 +363,31 @@ class TransformerLanguageModelBase(MegatronModule):
             pooled_output = self.pooler(encoder_output,
                                         pooling_sequence_index)
 
         # output_enc_hidden refers to when we just need the encoder's
         # output. For example, it is helpful to compute
         # similarity between two sequences by average pooling
         if not self.add_decoder or output_enc_hidden:
-            if self.add_pooler:
+            if self.add_pooler and mpu.is_pipeline_last_stage():
                 return encoder_output, pooled_output
             else:
                 return encoder_output
 
         # Decoder Embedding
         (dec_input_ids, dec_position_ids) = dec_language_model_input
         dec_embedding_output = self.embedding(dec_input_ids,
                                               dec_position_ids)
 
         # decoder
         decoder_output = self.decoder(dec_embedding_output,
                                       dec_attn_mask,
                                       layer_past=layer_past,
                                       get_key_value=get_key_value,
                                       encoder_output=encoder_output,
                                       enc_dec_attn_mask=enc_dec_attn_mask)
 
-        if self.add_pooler:
+        if self.add_pooler and mpu.is_pipeline_last_stage():
             return decoder_output, encoder_output, pooled_output
         else:
             return decoder_output, encoder_output
-
-        return encoder_output
 
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
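For readers tracing the control flow, the following is a hedged, self-contained sketch of the return-value routing the reworked forward implements. is_last_stage stands in for mpu.is_pipeline_last_stage() and the *_output arguments are placeholders, so this is illustrative rather than the actual Megatron API:

def route_outputs(encoder_output, pooled_output, decoder_output,
                  add_decoder, add_pooler, is_last_stage,
                  output_enc_hidden=False):
    # Encoder-only path: no decoder configured, or the caller only wants
    # the encoder's hidden states.
    if not add_decoder or output_enc_hidden:
        if add_pooler and is_last_stage:
            return encoder_output, pooled_output
        return encoder_output
    # Encoder-decoder path: the pooled output is only available on the
    # last pipeline stage.
    if add_pooler and is_last_stage:
        return decoder_output, encoder_output, pooled_output
    return decoder_output, encoder_output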
@@ -462,12 +460,12 @@ class TransformerLanguageModelBase(MegatronModule):
                     'could not find data for pooler in the checkpoint'
                 self.pooler.load_state_dict(state_dict[self._pooler_key],
                                             strict=strict)
-            # decoder
-            if self.add_decoder:
-                assert 'decoder' in state_dict, \
-                    'could not find data for pooler in the checkpoint'
-                self.decoder.load_state_dict(state_dict[self._decoder_key],
-                                             strict=strict)
+        # decoder
+        if self.add_decoder:
+            assert 'decoder' in state_dict, \
+                'could not find data for pooler in the checkpoint'
+            self.decoder.load_state_dict(state_dict[self._decoder_key],
+                                         strict=strict)
 
 
 class TransformerLanguageModel(TransformerLanguageModelBase):
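The checkpoint-loading hunk keeps the same key checks while moving the decoder branch alongside the rest of the decoder handling. A small standalone sketch of just those membership checks (a plain dict stands in for the checkpoint state_dict; the real code goes on to call load_state_dict on the sub-modules):

def check_language_model_keys(state_dict, add_pooler, add_decoder):
    # Same key checks as above, on a plain dict.
    if add_pooler:
        assert 'pooler' in state_dict, \
            'could not find data for pooler in the checkpoint'
    if add_decoder:
        assert 'decoder' in state_dict, \
            'could not find data for decoder in the checkpoint'


check_language_model_keys({'pooler': {}, 'decoder': {}},
                          add_pooler=True, add_decoder=True)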
@@ -577,30 +575,21 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
                  init_method,
                  output_layer_init_method,
                  self_attn_mask_type=AttnMaskType.padding,
-                 add_decoder=False,
                  add_pooler=False):
         super(TransformerLanguageModelLastStage, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
             self_attn_mask_type=AttnMaskType.padding,
-            add_decoder=add_decoder,
             add_pooler=add_pooler)
 
-    def forward(self, hidden_states, enc_attention_mask,
-                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
-                enc_dec_attn_mask=None, layer_past=None, get_key_value=False,
-                pooling_sequence_index=0, enc_hidden_states=None,
-                output_enc_hidden=False):
+    def forward(self, hidden_states, attention_mask,
+                layer_past=None, get_key_value=False,
+                pooling_sequence_index=0):
         return super(TransformerLanguageModelLastStage, self).forward(
             hidden_states,
-            enc_attention_mask,
-            dec_language_model_input=(dec_input_ids, dec_position_ids),
-            dec_attn_mask=dec_attn_mask,
-            enc_dec_attn_mask=enc_dec_attn_mask,
+            attention_mask,
             layer_past=layer_past,
             get_key_value=get_key_value,
             pooling_sequence_index=pooling_sequence_index,
-            enc_hidden_states=enc_hidden_states,
-            ouput_enc_hidden=output_enc_hidden
         )