Unverified commit 8d83ebdf, authored by NielsRogge, committed by GitHub
Browse files

[Tests] Add attentions_option to ModelTesterMixin (#15909)



* Add attentions_option to common tester

* Fix tests, apply suggestion

* Apply suggestion from code review
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 6ce11c2c
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
import inspect import inspect
import unittest import unittest
from typing import Dict, List, Tuple
from transformers import ConvNextConfig from transformers import ConvNextConfig
from transformers.file_utils import cached_property, is_torch_available, is_vision_available from transformers.file_utils import cached_property, is_torch_available, is_vision_available
...@@ -142,6 +141,7 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -142,6 +141,7 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase):
test_torchscript = False test_torchscript = False
test_resize_embeddings = False test_resize_embeddings = False
test_head_masking = False test_head_masking = False
has_attentions = False
def setUp(self): def setUp(self):
self.model_tester = ConvNextModelTester(self) self.model_tester = ConvNextModelTester(self)
...@@ -183,10 +183,6 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -183,10 +183,6 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="Model doesn't have attention layers")
def test_attention_outputs(self):
pass
def test_hidden_states_output(self): def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class): def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config) model = model_class(config)
...@@ -219,81 +215,6 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -219,81 +215,6 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase):
check_hidden_states_output(inputs_dict, config, model_class) check_hidden_states_output(inputs_dict, config, model_class)
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
config.output_attentions = True
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
model.to(torch_device)
inputs = self._prepare_for_class(inputs_dict, model_class)
outputs = model(**inputs)
output = outputs[0]
hidden_states = outputs.hidden_states[0]
hidden_states.retain_grad()
output.flatten()[0].backward(retain_graph=True)
self.assertIsNotNone(hidden_states.grad)
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def set_nan_tensor_to_zero(t):
t[t != t] = 0
return t
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
with torch.no_grad():
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(
tuple_object.values(), dict_object.values()
):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
msg=f"Tuple and dict output are not equal. Difference: {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`: {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}.",
)
recursive_check(tuple_output, dict_output)
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
def test_for_image_classification(self): def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs) self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
import inspect import inspect
import unittest import unittest
from typing import Dict, List, Tuple
from transformers import is_torch_available, is_vision_available from transformers import is_torch_available, is_vision_available
from transformers.models.auto import get_values from transformers.models.auto import get_values
...@@ -130,6 +129,7 @@ class PoolFormerModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -130,6 +129,7 @@ class PoolFormerModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False test_pruning = False
test_resize_embeddings = False test_resize_embeddings = False
test_torchscript = False test_torchscript = False
has_attentions = False
def setUp(self): def setUp(self):
self.model_tester = PoolFormerModelTester(self) self.model_tester = PoolFormerModelTester(self)
...@@ -150,100 +150,6 @@ class PoolFormerModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -150,100 +150,6 @@ class PoolFormerModelTest(ModelTesterMixin, unittest.TestCase):
def test_model_common_attributes(self): def test_model_common_attributes(self):
pass pass
def test_retain_grad_hidden_states_attentions(self):
# Since poolformer doesn't use Attention
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
model.to(torch_device)
inputs = self._prepare_for_class(inputs_dict, model_class)
outputs = model(**inputs)
output = outputs[0]
hidden_states = outputs.hidden_states[0]
hidden_states.retain_grad()
output.flatten()[0].backward(retain_graph=True)
self.assertIsNotNone(hidden_states.grad)
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def set_nan_tensor_to_zero(t):
t[t != t] = 0
return t
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
with torch.no_grad():
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(
tuple_object.values(), dict_object.values()
):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
msg=f"Tuple and dict output are not equal. Difference: {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`: {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}.",
)
recursive_check(tuple_output, dict_output)
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
@unittest.skip("PoolFormer does not have attention")
def test_attention_outputs(self):
pass
def test_hidden_states_output(self): def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class): def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config) model = model_class(config)
...@@ -297,6 +203,18 @@ class PoolFormerModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -297,6 +203,18 @@ class PoolFormerModelTest(ModelTesterMixin, unittest.TestCase):
loss = model(**inputs).loss loss = model(**inputs).loss
loss.backward() loss.backward()
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
...@@ -128,6 +128,7 @@ class ModelTesterMixin: ...@@ -128,6 +128,7 @@ class ModelTesterMixin:
test_missing_keys = True test_missing_keys = True
test_model_parallel = False test_model_parallel = False
is_encoder_decoder = False is_encoder_decoder = False
has_attentions = True
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = copy.deepcopy(inputs_dict) inputs_dict = copy.deepcopy(inputs_dict)
...@@ -454,119 +455,123 @@ class ModelTesterMixin: ...@@ -454,119 +455,123 @@ class ModelTesterMixin:
loss.backward() loss.backward()
def test_attention_outputs(self): def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() if not self.has_attentions:
config.return_dict = True pass
seq_len = getattr(self.model_tester, "seq_length", None)
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
chunk_length = getattr(self.model_tester, "chunk_length", None)
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
for model_class in self.all_model_classes: else:
inputs_dict["output_attentions"] = True config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
inputs_dict["output_hidden_states"] = False
config.return_dict = True config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config seq_len = getattr(self.model_tester, "seq_length", None)
del inputs_dict["output_attentions"] decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
config.output_attentions = True encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
model = model_class(config) decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
model.to(torch_device) encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
model.eval() chunk_length = getattr(self.model_tester, "chunk_length", None)
with torch.no_grad(): if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
outputs = model(**self._prepare_for_class(inputs_dict, model_class)) encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
if chunk_length is not None: inputs_dict["output_hidden_states"] = False
self.assertListEqual( config.return_dict = True
list(attentions[0].shape[-4:]), model = model_class(config)
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], model.to(torch_device)
) model.eval()
else: with torch.no_grad():
self.assertListEqual( outputs = model(**self._prepare_for_class(inputs_dict, model_class))
list(attentions[0].shape[-3:]), attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
)
out_len = len(outputs) # check that output_attentions also work using config
del inputs_dict["output_attentions"]
if self.is_encoder_decoder: config.output_attentions = True
correct_outlen = 5 model = model_class(config)
model.to(torch_device)
# loss is at first position model.eval()
if "labels" in inputs_dict: with torch.no_grad():
correct_outlen += 1 # loss is added to beginning outputs = model(**self._prepare_for_class(inputs_dict, model_class))
# Question Answering model returns start_logits and end_logits attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING): self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
if "past_key_values" in outputs: if chunk_length is not None:
correct_outlen += 1 # past_key_values have been returned self.assertListEqual(
list(attentions[0].shape[-4:]),
self.assertEqual(out_len, correct_outlen) [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
)
# decoder attentions else:
decoder_attentions = outputs.decoder_attentions self.assertListEqual(
self.assertIsInstance(decoder_attentions, (list, tuple)) list(attentions[0].shape[-3:]),
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers) [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
self.assertListEqual( )
list(decoder_attentions[0].shape[-3:]), out_len = len(outputs)
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
) if self.is_encoder_decoder:
correct_outlen = 5
# loss is at first position
if "labels" in inputs_dict:
correct_outlen += 1 # loss is added to beginning
# Question Answering model returns start_logits and end_logits
if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
if "past_key_values" in outputs:
correct_outlen += 1 # past_key_values have been returned
self.assertEqual(out_len, correct_outlen)
# decoder attentions
decoder_attentions = outputs.decoder_attentions
self.assertIsInstance(decoder_attentions, (list, tuple))
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
)
# cross attentions # cross attentions
cross_attentions = outputs.cross_attentions cross_attentions = outputs.cross_attentions
self.assertIsInstance(cross_attentions, (list, tuple)) self.assertIsInstance(cross_attentions, (list, tuple))
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers) self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual( self.assertListEqual(
list(cross_attentions[0].shape[-3:]), list(cross_attentions[0].shape[-3:]),
[ [
self.model_tester.num_attention_heads, self.model_tester.num_attention_heads,
decoder_seq_length, decoder_seq_length,
encoder_key_length, encoder_key_length,
], ],
) )
# Check attention is always last and order is fine # Check attention is always last and order is fine
inputs_dict["output_attentions"] = True inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True inputs_dict["output_hidden_states"] = True
model = model_class(config) model = model_class(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class)) outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if hasattr(self.model_tester, "num_hidden_states_types"): if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types added_hidden_states = self.model_tester.num_hidden_states_types
elif self.is_encoder_decoder: elif self.is_encoder_decoder:
added_hidden_states = 2 added_hidden_states = 2
else: else:
added_hidden_states = 1 added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs)) self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
if chunk_length is not None: if chunk_length is not None:
self.assertListEqual( self.assertListEqual(
list(self_attentions[0].shape[-4:]), list(self_attentions[0].shape[-4:]),
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
) )
else: else:
self.assertListEqual( self.assertListEqual(
list(self_attentions[0].shape[-3:]), list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
) )
@slow @slow
def test_torchscript(self): def test_torchscript(self):
...@@ -1040,7 +1045,7 @@ class ModelTesterMixin: ...@@ -1040,7 +1045,7 @@ class ModelTesterMixin:
def test_retain_grad_hidden_states_attentions(self): def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True config.output_hidden_states = True
config.output_attentions = True config.output_attentions = self.has_attentions
# no need to test all models as different heads yield the same functionality # no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0] model_class = self.all_model_classes[0]
...@@ -1056,37 +1061,45 @@ class ModelTesterMixin: ...@@ -1056,37 +1061,45 @@ class ModelTesterMixin:
if config.is_encoder_decoder: if config.is_encoder_decoder:
# Seq2Seq models # Seq2Seq models
encoder_hidden_states = outputs.encoder_hidden_states[0] encoder_hidden_states = outputs.encoder_hidden_states[0]
encoder_attentions = outputs.encoder_attentions[0]
encoder_hidden_states.retain_grad() encoder_hidden_states.retain_grad()
encoder_attentions.retain_grad()
decoder_hidden_states = outputs.decoder_hidden_states[0] decoder_hidden_states = outputs.decoder_hidden_states[0]
decoder_attentions = outputs.decoder_attentions[0]
decoder_hidden_states.retain_grad() decoder_hidden_states.retain_grad()
decoder_attentions.retain_grad()
cross_attentions = outputs.cross_attentions[0] if self.has_attentions:
cross_attentions.retain_grad() encoder_attentions = outputs.encoder_attentions[0]
encoder_attentions.retain_grad()
decoder_attentions = outputs.decoder_attentions[0]
decoder_attentions.retain_grad()
cross_attentions = outputs.cross_attentions[0]
cross_attentions.retain_grad()
output.flatten()[0].backward(retain_graph=True) output.flatten()[0].backward(retain_graph=True)
self.assertIsNotNone(encoder_hidden_states.grad) self.assertIsNotNone(encoder_hidden_states.grad)
self.assertIsNotNone(encoder_attentions.grad)
self.assertIsNotNone(decoder_hidden_states.grad) self.assertIsNotNone(decoder_hidden_states.grad)
self.assertIsNotNone(decoder_attentions.grad)
self.assertIsNotNone(cross_attentions.grad) if self.has_attentions:
self.assertIsNotNone(encoder_attentions.grad)
self.assertIsNotNone(decoder_attentions.grad)
self.assertIsNotNone(cross_attentions.grad)
else: else:
# Encoder-/Decoder-only models # Encoder-/Decoder-only models
hidden_states = outputs.hidden_states[0] hidden_states = outputs.hidden_states[0]
attentions = outputs.attentions[0]
hidden_states.retain_grad() hidden_states.retain_grad()
attentions.retain_grad()
if self.has_attentions:
attentions = outputs.attentions[0]
attentions.retain_grad()
output.flatten()[0].backward(retain_graph=True) output.flatten()[0].backward(retain_graph=True)
self.assertIsNotNone(hidden_states.grad) self.assertIsNotNone(hidden_states.grad)
self.assertIsNotNone(attentions.grad)
if self.has_attentions:
self.assertIsNotNone(attentions.grad)
def test_feed_forward_chunking(self): def test_feed_forward_chunking(self):
( (
...@@ -1424,23 +1437,24 @@ class ModelTesterMixin: ...@@ -1424,23 +1437,24 @@ class ModelTesterMixin:
dict_inputs = self._prepare_for_class(inputs_dict, model_class) dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) if self.has_attentions:
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True}) dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence( check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
) tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(
model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
)
@is_pt_tf_cross_test @is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self): def test_pt_tf_model_equivalence(self):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment