Unverified Commit dfc76b25 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

has_attentions - consistent test skipping logic and tf tests (#17495)

parent 66e86567
...@@ -158,6 +158,10 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -158,6 +158,10 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase):
def create_and_test_config_common_properties(self): def create_and_test_config_common_properties(self):
return return
@unittest.skip(reason="ConvNext does not output attentions")
def test_attention_outputs(self):
pass
@unittest.skip(reason="ConvNext does not use inputs_embeds") @unittest.skip(reason="ConvNext does not use inputs_embeds")
def test_inputs_embeds(self): def test_inputs_embeds(self):
pass pass
......
...@@ -173,6 +173,10 @@ class CvtModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -173,6 +173,10 @@ class CvtModelTest(ModelTesterMixin, unittest.TestCase):
def create_and_test_config_common_properties(self): def create_and_test_config_common_properties(self):
return return
@unittest.skip(reason="Cvt does not output attentions")
def test_attention_outputs(self):
pass
@unittest.skip(reason="Cvt does not use inputs_embeds") @unittest.skip(reason="Cvt does not use inputs_embeds")
def test_inputs_embeds(self): def test_inputs_embeds(self):
pass pass
......
...@@ -695,6 +695,10 @@ class FlavaImageCodebookTest(ModelTesterMixin, unittest.TestCase): ...@@ -695,6 +695,10 @@ class FlavaImageCodebookTest(ModelTesterMixin, unittest.TestCase):
expected_arg_names = ["pixel_values"] expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names) self.assertListEqual(arg_names[:1], expected_arg_names)
@unittest.skip(reason="Flava does not output attentions")
def test_attention_outputs(self):
pass
def test_model_common_attributes(self): def test_model_common_attributes(self):
# No embedding in multimodal model # No embedding in multimodal model
pass pass
......
...@@ -142,6 +142,10 @@ class PoolFormerModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -142,6 +142,10 @@ class PoolFormerModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="PoolFormer does not output attentions")
def test_attention_outputs(self):
pass
@unittest.skip("PoolFormer does not use inputs_embeds") @unittest.skip("PoolFormer does not use inputs_embeds")
def test_inputs_embeds(self): def test_inputs_embeds(self):
pass pass
......
...@@ -147,6 +147,10 @@ class RegNetModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -147,6 +147,10 @@ class RegNetModelTest(ModelTesterMixin, unittest.TestCase):
def create_and_test_config_common_properties(self): def create_and_test_config_common_properties(self):
return return
@unittest.skip(reason="RegNet does not output attentions")
def test_attention_outputs(self):
pass
@unittest.skip(reason="RegNet does not use inputs_embeds") @unittest.skip(reason="RegNet does not use inputs_embeds")
def test_inputs_embeds(self): def test_inputs_embeds(self):
pass pass
......
...@@ -147,6 +147,10 @@ class ResNetModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -147,6 +147,10 @@ class ResNetModelTest(ModelTesterMixin, unittest.TestCase):
def create_and_test_config_common_properties(self): def create_and_test_config_common_properties(self):
return return
@unittest.skip(reason="ResNet does not output attentions")
def test_attention_outputs(self):
pass
@unittest.skip(reason="ResNet does not use inputs_embeds") @unittest.skip(reason="ResNet does not use inputs_embeds")
def test_inputs_embeds(self): def test_inputs_embeds(self):
pass pass
......
...@@ -144,6 +144,10 @@ class VanModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -144,6 +144,10 @@ class VanModelTest(ModelTesterMixin, unittest.TestCase):
def create_and_test_config_common_properties(self): def create_and_test_config_common_properties(self):
return return
@unittest.skip(reason="Van does not output attentions")
def test_attention_outputs(self):
pass
@unittest.skip(reason="Van does not use inputs_embeds") @unittest.skip(reason="Van does not use inputs_embeds")
def test_inputs_embeds(self): def test_inputs_embeds(self):
pass pass
......
...@@ -485,10 +485,6 @@ class ModelTesterMixin: ...@@ -485,10 +485,6 @@ class ModelTesterMixin:
loss.backward() loss.backward()
def test_attention_outputs(self): def test_attention_outputs(self):
if not self.has_attentions:
pass
else:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True config.return_dict = True
......
...@@ -978,6 +978,7 @@ class TFModelTesterMixin: ...@@ -978,6 +978,7 @@ class TFModelTesterMixin:
dict_inputs = self._prepare_for_class(inputs_dict, model_class) dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
if self.has_attentions:
tuple_inputs = self._prepare_for_class(inputs_dict, model_class) tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class) dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True}) check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
...@@ -992,6 +993,7 @@ class TFModelTesterMixin: ...@@ -992,6 +993,7 @@ class TFModelTesterMixin:
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
if self.has_attentions:
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True}) check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment