"...lm-evaluation-harness.git" did not exist on "55749b9b4535ec1bf9d325f29cddd1fbc1d718ed"
Commit 808bb8da authored by thomwolf's avatar thomwolf
Browse files

fix transfo xl tests

parent b016dd16
......@@ -125,6 +125,11 @@ class CommonTestCases:
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
decoder_seq_length = self.model_tester.decoder_seq_length if hasattr(self.model_tester, 'decoder_seq_length') else self.model_tester.seq_length
encoder_seq_length = self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length
decoder_key_length = self.model_tester.key_length if hasattr(self.model_tester, 'key_length') else decoder_seq_length
encoder_key_length = self.model_tester.key_length if hasattr(self.model_tester, 'key_length') else encoder_seq_length
for model_class in self.all_model_classes:
config.output_attentions = True
config.output_hidden_states = False
......@@ -138,8 +143,8 @@ class CommonTestCases:
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads,
self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length,
self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length])
encoder_seq_length ,
encoder_key_length])
out_len = len(outputs)
if self.is_encoder_decoder:
......@@ -151,8 +156,9 @@ class CommonTestCases:
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads,
self.model_tester.decoder_seq_length if hasattr(self.model_tester, 'decoder_seq_length') else self.model_tester.seq_length,
self.model_tester.decoder_seq_length if hasattr(self.model_tester, 'decoder_seq_length') else self.model_tester.seq_length])
decoder_seq_length,
decoder_key_length
])
# Check attention is always last and order is fine
config.output_attentions = True
......@@ -169,8 +175,8 @@ class CommonTestCases:
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads,
self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length,
self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length])
encoder_seq_length,
encoder_key_length])
def test_torchscript(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
......@@ -68,7 +68,7 @@ class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester):
self.batch_size = batch_size
self.seq_length = seq_length
self.mem_len = mem_len
self.key_len = seq_length + mem_len
self.key_length = seq_length + mem_len
self.clamp_len = clamp_len
self.is_training = is_training
self.use_labels = use_labels
......
......@@ -66,7 +66,7 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester):
self.batch_size = batch_size
self.seq_length = seq_length
self.mem_len = mem_len
self.key_len = seq_length + mem_len
self.key_length = seq_length + mem_len
self.clamp_len = clamp_len
self.is_training = is_training
self.use_labels = use_labels
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment