Unverified Commit 59499bbe authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

Update forward signature test for vision models (#27681)

* Update forward signature

* Empty-Commit
parent 1d7f406e
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
""" Testing suite for the PyTorch ViT model. """ """ Testing suite for the PyTorch ViT model. """
import inspect
import unittest import unittest
from transformers import ViTConfig from transformers import ViTConfig
...@@ -224,18 +223,6 @@ class ViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): ...@@ -224,18 +223,6 @@ class ViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
x = model.get_output_embeddings() x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear)) self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self): def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_model(*config_and_inputs)
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
""" Testing suite for the PyTorch ViT Hybrid model. """ """ Testing suite for the PyTorch ViT Hybrid model. """
import inspect
import unittest import unittest
from transformers import ViTHybridConfig from transformers import ViTHybridConfig
...@@ -185,18 +184,6 @@ class ViTHybridModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas ...@@ -185,18 +184,6 @@ class ViTHybridModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
x = model.get_output_embeddings() x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear)) self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self): def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_model(*config_and_inputs)
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
""" Testing suite for the PyTorch ViTMAE model. """ """ Testing suite for the PyTorch ViTMAE model. """
import inspect
import math import math
import tempfile import tempfile
import unittest import unittest
...@@ -192,18 +191,6 @@ class ViTMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): ...@@ -192,18 +191,6 @@ class ViTMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
x = model.get_output_embeddings() x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear)) self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self): def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_model(*config_and_inputs)
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
""" Testing suite for the PyTorch ViTMSN model. """ """ Testing suite for the PyTorch ViTMSN model. """
import inspect
import unittest import unittest
from transformers import ViTMSNConfig from transformers import ViTMSNConfig
...@@ -183,18 +182,6 @@ class ViTMSNModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): ...@@ -183,18 +182,6 @@ class ViTMSNModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
x = model.get_output_embeddings() x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear)) self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self): def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_model(*config_and_inputs)
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
""" Testing suite for the PyTorch ViTDet model. """ """ Testing suite for the PyTorch ViTDet model. """
import inspect
import unittest import unittest
from transformers import VitDetConfig from transformers import VitDetConfig
...@@ -210,18 +209,6 @@ class VitDetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): ...@@ -210,18 +209,6 @@ class VitDetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
x = model.get_output_embeddings() x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear)) self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self): def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_model(*config_and_inputs)
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
""" Testing suite for the PyTorch VitMatte model. """ """ Testing suite for the PyTorch VitMatte model. """
import inspect
import unittest import unittest
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
...@@ -189,18 +188,6 @@ class VitMatteModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase ...@@ -189,18 +188,6 @@ class VitMatteModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
def test_model_common_attributes(self): def test_model_common_attributes(self):
pass pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self): def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_model(*config_and_inputs)
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
""" Testing suite for the PyTorch YOLOS model. """ """ Testing suite for the PyTorch YOLOS model. """
import inspect
import unittest import unittest
from transformers import YolosConfig from transformers import YolosConfig
...@@ -217,18 +216,6 @@ class YolosModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): ...@@ -217,18 +216,6 @@ class YolosModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
x = model.get_output_embeddings() x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear)) self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self): def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_model(*config_and_inputs)
......
...@@ -543,7 +543,7 @@ class ModelTesterMixin: ...@@ -543,7 +543,7 @@ class ModelTesterMixin:
) )
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
else: else:
expected_arg_names = ["input_ids"] expected_arg_names = [model.main_input_name]
self.assertListEqual(arg_names[:1], expected_arg_names) self.assertListEqual(arg_names[:1], expected_arg_names)
def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None): def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment