"vscode:/vscode.git/clone" did not exist on "ca2047bc352e32f8d6dc26f4e55c2556149230d9"
Unverified Commit 10237054 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Check models used for common tests are small (#24824)

* First models

* Conditional DETR

* Treat DETR models, skip others

* Skip LayoutLMv2 as well

* Fix last tests
parent a865b62e
...@@ -279,6 +279,10 @@ class ViTMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): ...@@ -279,6 +279,10 @@ class ViTMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_model_outputs_equivalence(self): def test_model_outputs_equivalence(self):
pass pass
@unittest.skip("Will be fixed soon by reducing the size of the model used for common tests.")
def test_model_is_small(self):
pass
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in VIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in VIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
...@@ -310,6 +310,10 @@ class VivitModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -310,6 +310,10 @@ class VivitModelTest(ModelTesterMixin, unittest.TestCase):
check_hidden_states_output(inputs_dict, config, model_class) check_hidden_states_output(inputs_dict, config, model_class)
@unittest.skip("Will be fixed soon by reducing the size of the model used for common tests.")
def test_model_is_small(self):
pass
# We will verify our results on a video of eating spaghetti # We will verify our results on a video of eating spaghetti
# Frame indices used: [164 168 172 176 181 185 189 193 198 202 206 210 215 219 223 227] # Frame indices used: [164 168 172 176 181 185 189 193 198 202 206 210 215 219 223 227]
......
...@@ -2705,6 +2705,18 @@ class ModelTesterMixin: ...@@ -2705,6 +2705,18 @@ class ModelTesterMixin:
else: else:
new_model_without_prefix(input_ids) new_model_without_prefix(input_ids)
def test_model_is_small(self):
# Just a consistency check to make sure we are not running tests on 80M parameter models.
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
# print(config)
for model_class in self.all_model_classes:
model = model_class(config)
num_params = model.num_parameters()
assert (
num_params < 1000000
), f"{model_class} is too big for the common tests ({num_params})! It should have 200k max."
global_rng = random.Random() global_rng = random.Random()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment