Unverified Commit 18c8cf00 authored by Yossi Synett's avatar Yossi Synett Committed by GitHub
Browse files

Fix bug in x-attentions output for roberta and harden test to catch it (#8660)

parent 48cc2247
...@@ -814,7 +814,7 @@ class RobertaForCausalLM(RobertaPreTrainedModel): ...@@ -814,7 +814,7 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
logits=prediction_scores, logits=prediction_scores,
hidden_states=outputs.hidden_states, hidden_states=outputs.hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
cross_attentions=outputs.attentions, cross_attentions=outputs.cross_attentions,
) )
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
......
...@@ -300,6 +300,9 @@ class EncoderDecoderMixin: ...@@ -300,6 +300,9 @@ class EncoderDecoderMixin:
labels, labels,
**kwargs **kwargs
): ):
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device) enc_dec_model.to(torch_device)
...@@ -314,9 +317,8 @@ class EncoderDecoderMixin: ...@@ -314,9 +317,8 @@ class EncoderDecoderMixin:
encoder_attentions = outputs_encoder_decoder["encoder_attentions"] encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
self.assertEqual(len(encoder_attentions), config.num_hidden_layers) self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
self.assertListEqual( self.assertEqual(
list(encoder_attentions[0].shape[-3:]), encoder_attentions[0].shape[-3:], (config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1])
[config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]],
) )
decoder_attentions = outputs_encoder_decoder["decoder_attentions"] decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
...@@ -327,20 +329,20 @@ class EncoderDecoderMixin: ...@@ -327,20 +329,20 @@ class EncoderDecoderMixin:
) )
self.assertEqual(len(decoder_attentions), num_decoder_layers) self.assertEqual(len(decoder_attentions), num_decoder_layers)
self.assertListEqual( self.assertEqual(
list(decoder_attentions[0].shape[-3:]), decoder_attentions[0].shape[-3:],
[decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]], (decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
) )
cross_attentions = outputs_encoder_decoder["cross_attentions"] cross_attentions = outputs_encoder_decoder["cross_attentions"]
self.assertEqual(len(cross_attentions), num_decoder_layers) self.assertEqual(len(cross_attentions), num_decoder_layers)
cross_attention_input_seq_len = input_ids.shape[-1] * ( cross_attention_input_seq_len = decoder_input_ids.shape[-1] * (
1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0) 1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0)
) )
self.assertListEqual( self.assertEqual(
list(cross_attentions[0].shape[-3:]), cross_attentions[0].shape[-3:],
[decoder_config.num_attention_heads, cross_attention_input_seq_len, decoder_input_ids.shape[-1]], (decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
) )
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs): def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment