Commit cbcb83f2 authored by sshleifer's avatar sshleifer Committed by Lysandre Debut
Browse files

minor cleanup of test_attention_outputs

parent 3bf54172
...@@ -117,23 +117,11 @@ class ModelTesterMixin: ...@@ -117,23 +117,11 @@ class ModelTesterMixin:
def test_attention_outputs(self): def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
seq_len = self.model_tester.seq_length
decoder_seq_length = ( decoder_seq_length = getattr(self.model_tester, 'decoder_seq_length', seq_len)
self.model_tester.decoder_seq_length encoder_seq_length = getattr(self.model_tester, 'encoder_seq_length', seq_len)
if hasattr(self.model_tester, "decoder_seq_length") decoder_key_length = getattr(self.model_tester, "key_length", decoder_seq_length)
else self.model_tester.seq_length encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
)
encoder_seq_length = (
self.model_tester.encoder_seq_length
if hasattr(self.model_tester, "encoder_seq_length")
else self.model_tester.seq_length
)
decoder_key_length = (
self.model_tester.key_length if hasattr(self.model_tester, "key_length") else decoder_seq_length
)
encoder_key_length = (
self.model_tester.key_length if hasattr(self.model_tester, "key_length") else encoder_seq_length
)
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
config.output_attentions = True config.output_attentions = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment