Commit 808bb8da authored by thomwolf's avatar thomwolf
Browse files

fix transfo xl tests

parent b016dd16
...@@ -125,6 +125,11 @@ class CommonTestCases: ...@@ -125,6 +125,11 @@ class CommonTestCases:
def test_attention_outputs(self): def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
decoder_seq_length = self.model_tester.decoder_seq_length if hasattr(self.model_tester, 'decoder_seq_length') else self.model_tester.seq_length
encoder_seq_length = self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length
decoder_key_length = self.model_tester.key_length if hasattr(self.model_tester, 'key_length') else decoder_seq_length
encoder_key_length = self.model_tester.key_length if hasattr(self.model_tester, 'key_length') else encoder_seq_length
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
config.output_attentions = True config.output_attentions = True
config.output_hidden_states = False config.output_hidden_states = False
...@@ -138,8 +143,8 @@ class CommonTestCases: ...@@ -138,8 +143,8 @@ class CommonTestCases:
self.assertListEqual( self.assertListEqual(
list(attentions[0].shape[-3:]), list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, [self.model_tester.num_attention_heads,
self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length, encoder_seq_length ,
self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length]) encoder_key_length])
out_len = len(outputs) out_len = len(outputs)
if self.is_encoder_decoder: if self.is_encoder_decoder:
...@@ -151,8 +156,9 @@ class CommonTestCases: ...@@ -151,8 +156,9 @@ class CommonTestCases:
self.assertListEqual( self.assertListEqual(
list(decoder_attentions[0].shape[-3:]), list(decoder_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, [self.model_tester.num_attention_heads,
self.model_tester.decoder_seq_length if hasattr(self.model_tester, 'decoder_seq_length') else self.model_tester.seq_length, decoder_seq_length,
self.model_tester.decoder_seq_length if hasattr(self.model_tester, 'decoder_seq_length') else self.model_tester.seq_length]) decoder_key_length
])
# Check attention is always last and order is fine # Check attention is always last and order is fine
config.output_attentions = True config.output_attentions = True
...@@ -169,8 +175,8 @@ class CommonTestCases: ...@@ -169,8 +175,8 @@ class CommonTestCases:
self.assertListEqual( self.assertListEqual(
list(self_attentions[0].shape[-3:]), list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, [self.model_tester.num_attention_heads,
self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length, encoder_seq_length,
self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length]) encoder_key_length])
def test_torchscript(self): def test_torchscript(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
...@@ -68,7 +68,7 @@ class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -68,7 +68,7 @@ class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester):
self.batch_size = batch_size self.batch_size = batch_size
self.seq_length = seq_length self.seq_length = seq_length
self.mem_len = mem_len self.mem_len = mem_len
self.key_len = seq_length + mem_len self.key_length = seq_length + mem_len
self.clamp_len = clamp_len self.clamp_len = clamp_len
self.is_training = is_training self.is_training = is_training
self.use_labels = use_labels self.use_labels = use_labels
......
...@@ -66,7 +66,7 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester): ...@@ -66,7 +66,7 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester):
self.batch_size = batch_size self.batch_size = batch_size
self.seq_length = seq_length self.seq_length = seq_length
self.mem_len = mem_len self.mem_len = mem_len
self.key_len = seq_length + mem_len self.key_length = seq_length + mem_len
self.clamp_len = clamp_len self.clamp_len = clamp_len
self.is_training = is_training self.is_training = is_training
self.use_labels = use_labels self.use_labels = use_labels
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment