Commit 8e651f56 authored by thomwolf's avatar thomwolf
Browse files

fix tf tests

parent 808bb8da
......@@ -213,6 +213,11 @@ class TFCommonTestCases:
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
decoder_seq_length = self.model_tester.decoder_seq_length if hasattr(self.model_tester, 'decoder_seq_length') else self.model_tester.seq_length
encoder_seq_length = self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length
decoder_key_length = self.model_tester.key_length if hasattr(self.model_tester, 'key_length') else decoder_seq_length
encoder_key_length = self.model_tester.key_length if hasattr(self.model_tester, 'key_length') else encoder_seq_length
for model_class in self.all_model_classes:
config.output_attentions = True
config.output_hidden_states = False
......@@ -225,8 +230,8 @@ class TFCommonTestCases:
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads,
self.model_tester.seq_length,
self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
encoder_seq_length,
encoder_key_length])
out_len = len(outputs)
if self.is_encoder_decoder:
......@@ -238,8 +243,8 @@ class TFCommonTestCases:
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads,
self.model_tester.seq_length,
self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
decoder_seq_length,
decoder_key_length])
# Check attention is always last and order is fine
config.output_attentions = True
......@@ -255,8 +260,8 @@ class TFCommonTestCases:
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads,
self.model_tester.seq_length,
self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
encoder_seq_length,
encoder_key_length])
def test_headmasking(self):
pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment