Commit 4b3519cb authored by Vijay Korthikanti

address review comments

parent 834d6dd5
@@ -316,11 +316,10 @@ class TransformerLanguageModelBase(MegatronModule):
             self_attn_mask_type=self_attn_mask_type)
         self._encoder_key = 'encoder'
 
-        # assuming pooler and decoder are in the last stage
-        # of the pipeline(to be revised)
-        if mpu.is_pipeline_last_stage():
-            # decoder
-            if self.add_decoder:
+        # Decoder
+        if self.add_decoder:
+            assert args.pipeline_model_parallel_size == 1, \
+                'pipeline parallelism is not supported in the presence of decoder'
             self.decoder = ParallelTransformer(
                 attention_mask_func,
                 self.init_method,
@@ -329,6 +328,7 @@ class TransformerLanguageModelBase(MegatronModule):
                 self_attn_mask_type=AttnMaskType.causal)
             self._decoder_key = 'decoder'
 
+        if mpu.is_pipeline_last_stage():
             # Pooler.
             if self.add_pooler:
                 self.pooler = Pooler(self.hidden_size, self.init_method)
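The hunks above drop the old assumption that the decoder lives on the last pipeline stage: the decoder is now built whenever `add_decoder` is set, but only without pipeline parallelism, while the pooler stays gated on `mpu.is_pipeline_last_stage()`. A minimal standalone sketch of that branching, using a stub in place of `megatron.mpu` and a plain namespace for `args` (both are assumptions, not the real Megatron objects):

```python
from types import SimpleNamespace


def is_pipeline_last_stage():
    # Stand-in for mpu.is_pipeline_last_stage(); a single-stage run is assumed.
    return True


def optional_modules(args, add_decoder=False, add_pooler=False):
    """Which optional heads the constructor above would build."""
    modules = ['encoder']  # the encoder is always constructed

    # Decoder: only legal without pipeline parallelism, mirroring the new assert.
    if add_decoder:
        assert args.pipeline_model_parallel_size == 1, \
            'pipeline parallelism is not supported in the presence of decoder'
        modules.append('decoder')

    # Pooler: still restricted to the last pipeline stage.
    if is_pipeline_last_stage() and add_pooler:
        modules.append('pooler')
    return modules


print(optional_modules(SimpleNamespace(pipeline_model_parallel_size=1),
                       add_decoder=True, add_pooler=True))
# -> ['encoder', 'decoder', 'pooler']
```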
@@ -367,7 +367,7 @@ class TransformerLanguageModelBase(MegatronModule):
         # output. For example, it is helpful to compute
         # similarity between two sequences by average pooling
         if not self.add_decoder or output_enc_hidden:
-            if self.add_pooler:
+            if self.add_pooler and mpu.is_pipeline_last_stage():
                 return encoder_output, pooled_output
             else:
                 return encoder_output
@@ -384,13 +384,11 @@ class TransformerLanguageModelBase(MegatronModule):
                                           encoder_output=encoder_output,
                                           enc_dec_attn_mask=enc_dec_attn_mask)
 
-            if self.add_pooler:
+            if self.add_pooler and mpu.is_pipeline_last_stage():
                 return decoder_output, encoder_output, pooled_output
             else:
                 return decoder_output, encoder_output
 
-        return encoder_output
-
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         """For easy load."""
@@ -577,30 +575,21 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
                  init_method,
                  output_layer_init_method,
                  self_attn_mask_type=AttnMaskType.padding,
-                 add_decoder=False,
                  add_pooler=False):
         super(TransformerLanguageModelLastStage, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
             self_attn_mask_type=AttnMaskType.padding,
-            add_decoder=add_decoder,
             add_pooler=add_pooler)
 
-    def forward(self, hidden_states, enc_attention_mask,
-                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
-                enc_dec_attn_mask=None, layer_past=None, get_key_value=False,
-                pooling_sequence_index=0, enc_hidden_states=None,
-                output_enc_hidden=False):
+    def forward(self, hidden_states, attention_mask,
+                layer_past=None, get_key_value=False,
+                pooling_sequence_index=0):
         return super(TransformerLanguageModelLastStage, self).forward(
             hidden_states,
-            enc_attention_mask,
-            dec_language_model_input=(dec_input_ids, dec_position_ids),
-            dec_attn_mask=dec_attn_mask,
-            enc_dec_attn_mask=enc_dec_attn_mask,
+            attention_mask,
             layer_past=layer_past,
             get_key_value=get_key_value,
             pooling_sequence_index=pooling_sequence_index,
-            enc_hidden_states=enc_hidden_states,
-            ouput_enc_hidden=output_enc_hidden
         )
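With decoder support removed from the last-stage wrapper, its forward now only accepts the encoder-side arguments. A tiny stand-in class with the same narrowed signature (not the real `TransformerLanguageModelLastStage`, which needs Megatron arguments and parallel state to construct):

```python
class LastStageStub:
    def forward(self, hidden_states, attention_mask,
                layer_past=None, get_key_value=False,
                pooling_sequence_index=0):
        # The decoder-related keywords (dec_input_ids, dec_attn_mask,
        # enc_dec_attn_mask, enc_hidden_states, output_enc_hidden) are gone;
        # only the encoder/pooler path remains on the last stage.
        return hidden_states


print(LastStageStub().forward('hidden_states', 'attention_mask'))
```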