Commit 4b3519cb authored by Vijay Korthikanti

address review comments

parent 834d6dd5
@@ -316,19 +316,19 @@ class TransformerLanguageModelBase(MegatronModule):
                 self_attn_mask_type=self_attn_mask_type)
             self._encoder_key = 'encoder'
 
-        # assuming pooler and decoder are in the last stage
-        # of the pipeline(to be revised)
-        if mpu.is_pipeline_last_stage():
-            # decoder
-            if self.add_decoder:
-                self.decoder = ParallelTransformer(
-                    attention_mask_func,
-                    self.init_method,
-                    output_layer_init_method,
-                    layer_type=LayerType.decoder,
-                    self_attn_mask_type=AttnMaskType.causal)
-                self._decoder_key = 'decoder'
+        # Decoder
+        if self.add_decoder:
+            assert args.pipeline_model_parallel_size == 1, \
+                'pipeline parallelism is not supported in the presence of decoder'
+            self.decoder = ParallelTransformer(
+                attention_mask_func,
+                self.init_method,
+                output_layer_init_method,
+                layer_type=LayerType.decoder,
+                self_attn_mask_type=AttnMaskType.causal)
+            self._decoder_key = 'decoder'
 
+        if mpu.is_pipeline_last_stage():
             # Pooler.
             if self.add_pooler:
                 self.pooler = Pooler(self.hidden_size, self.init_method)
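To make the new guard concrete: the decoder is no longer built behind `mpu.is_pipeline_last_stage()`; instead, construction fails fast when a decoder is combined with pipeline parallelism. A minimal sketch of that contract (class and argument names here are illustrative stand-ins, not Megatron code):

    # Sketch only: mirrors the assert added in the hunk above.
    class TinyEncoderDecoder:
        def __init__(self, add_decoder=False, pipeline_model_parallel_size=1):
            if add_decoder:
                # Decoder plus multi-stage pipelines is unsupported.
                assert pipeline_model_parallel_size == 1, \
                    'pipeline parallelism is not supported in the presence of decoder'
            self.add_decoder = add_decoder

    TinyEncoderDecoder(add_decoder=True)  # fine: single pipeline stage
    # TinyEncoderDecoder(add_decoder=True, pipeline_model_parallel_size=2)
    # -> AssertionError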
@@ -363,33 +363,31 @@ class TransformerLanguageModelBase(MegatronModule):
             pooled_output = self.pooler(encoder_output,
                                         pooling_sequence_index)
 
-        # output_enc_hidden refers to when we just need the encoder's
-        # output. For example, it is helpful to compute
-        # similarity between two sequences by average pooling
-        if not self.add_decoder or output_enc_hidden:
-            if self.add_pooler:
-                return encoder_output, pooled_output
-            else:
-                return encoder_output
-
-        # Decoder Embedding
-        (dec_input_ids, dec_position_ids) = dec_language_model_input
-        dec_embedding_output = self.embedding(dec_input_ids,
-                                              dec_position_ids)
-
-        # decoder
-        decoder_output = self.decoder(dec_embedding_output,
-                                      dec_attn_mask,
-                                      layer_past=layer_past,
-                                      get_key_value=get_key_value,
-                                      encoder_output=encoder_output,
-                                      enc_dec_attn_mask=enc_dec_attn_mask)
-
-        if self.add_pooler:
-            return decoder_output, encoder_output, pooled_output
-        else:
-            return decoder_output, encoder_output
+        # output_enc_hidden refers to when we just need the encoder's
+        # output. For example, it is helpful to compute
+        # similarity between two sequences by average pooling
+        if not self.add_decoder or output_enc_hidden:
+            if self.add_pooler and mpu.is_pipeline_last_stage():
+                return encoder_output, pooled_output
+            return encoder_output
+
+        # Decoder Embedding
+        (dec_input_ids, dec_position_ids) = dec_language_model_input
+        dec_embedding_output = self.embedding(dec_input_ids,
+                                              dec_position_ids)
+
+        # decoder
+        decoder_output = self.decoder(dec_embedding_output,
+                                      dec_attn_mask,
+                                      layer_past=layer_past,
+                                      get_key_value=get_key_value,
+                                      encoder_output=encoder_output,
+                                      enc_dec_attn_mask=enc_dec_attn_mask)
+
+        if self.add_pooler and mpu.is_pipeline_last_stage():
+            return decoder_output, encoder_output, pooled_output
+        else:
+            return decoder_output, encoder_output
 
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
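The revised forward gates every pooled return on `mpu.is_pipeline_last_stage()`, so the return shape depends on three flags. A small helper (hypothetical, written only to enumerate the contract implied by this hunk) makes the cases explicit:

    def expected_outputs(add_decoder, add_pooler, is_last_stage,
                         output_enc_hidden=False):
        """Illustrative only: names of the tensors forward() returns."""
        if not add_decoder or output_enc_hidden:
            # Encoder-only path; the pooled output exists only on the last stage.
            if add_pooler and is_last_stage:
                return ('encoder_output', 'pooled_output')
            return ('encoder_output',)
        # Encoder-decoder path.
        if add_pooler and is_last_stage:
            return ('decoder_output', 'encoder_output', 'pooled_output')
        return ('decoder_output', 'encoder_output')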
@@ -462,12 +460,12 @@ class TransformerLanguageModelBase(MegatronModule):
                 'could not find data for pooler in the checkpoint'
             self.pooler.load_state_dict(state_dict[self._pooler_key],
                                         strict=strict)
 
-            # decoder
-            if self.add_decoder:
-                assert 'decoder' in state_dict, \
-                    'could not find data for pooler in the checkpoint'
-                self.decoder.load_state_dict(state_dict[self._decoder_key],
-                                             strict=strict)
+        # decoder
+        if self.add_decoder:
+            assert 'decoder' in state_dict, \
+                'could not find data for decoder in the checkpoint'
+            self.decoder.load_state_dict(state_dict[self._decoder_key],
+                                         strict=strict)
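Consistent with the constructor change, the decoder branch of checkpoint loading is dedented out of the last-stage-only path, so it runs on every rank. A sketch of that branch in isolation (the wrapper function is illustrative; the key and attribute names come from the diff):

    def load_decoder_weights(module, state_dict, strict=True):
        # Illustrative wrapper around the hunk above: decoder weights are
        # loaded wherever the decoder exists, not only on the last stage.
        if module.add_decoder:
            assert 'decoder' in state_dict, \
                'could not find data for decoder in the checkpoint'
            module.decoder.load_state_dict(state_dict[module._decoder_key],
                                           strict=strict)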
class TransformerLanguageModel(TransformerLanguageModelBase):
@@ -577,30 +575,21 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
                  init_method,
                  output_layer_init_method,
                  self_attn_mask_type=AttnMaskType.padding,
-                 add_decoder=False,
                  add_pooler=False):
         super(TransformerLanguageModelLastStage, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
             self_attn_mask_type=AttnMaskType.padding,
-            add_decoder=add_decoder,
             add_pooler=add_pooler)
 
-    def forward(self, hidden_states, enc_attention_mask,
-                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
-                enc_dec_attn_mask=None, layer_past=None, get_key_value=False,
-                pooling_sequence_index=0, enc_hidden_states=None,
-                output_enc_hidden=False):
+    def forward(self, hidden_states, attention_mask,
+                layer_past=None, get_key_value=False,
+                pooling_sequence_index=0):
         return super(TransformerLanguageModelLastStage, self).forward(
             hidden_states,
-            enc_attention_mask,
-            dec_language_model_input=(dec_input_ids, dec_position_ids),
-            dec_attn_mask=dec_attn_mask,
-            enc_dec_attn_mask=enc_dec_attn_mask,
+            attention_mask,
             layer_past=layer_past,
             get_key_value=get_key_value,
             pooling_sequence_index=pooling_sequence_index,
-            enc_hidden_states=enc_hidden_states,
-            ouput_enc_hidden=output_enc_hidden
         )
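With decoder plumbing removed, the last-stage wrapper now takes only encoder-side arguments. An assumed call shape (variable names are placeholders, not a working Megatron setup):

    # `model` is a TransformerLanguageModelLastStage instance; the two
    # tensors arrive as activations from the previous pipeline stage.
    output = model(hidden_states,             # hidden states from stage N-1
                   attention_mask,            # encoder self-attention mask
                   pooling_sequence_index=0)  # token index used by the pooler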