Unverified Commit 42571f6e authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Make more test models smaller (#25005)

* Make more test models tiny

* Make more test models tiny

* More models

* More models
parent 8f1f0bf5
...@@ -115,6 +115,9 @@ class MobileViTV2ModelTester: ...@@ -115,6 +115,9 @@ class MobileViTV2ModelTester:
width_multiplier=self.width_multiplier, width_multiplier=self.width_multiplier,
ffn_dropout=self.ffn_dropout_prob, ffn_dropout=self.ffn_dropout_prob,
attn_dropout=self.attn_dropout_prob, attn_dropout=self.attn_dropout_prob,
base_attn_unit_dims=[16, 24, 32],
n_attn_blocks=[1, 1, 2],
aspp_out_channels=32,
) )
def create_and_check_model(self, config, pixel_values, labels, pixel_labels): def create_and_check_model(self, config, pixel_values, labels, pixel_labels):
...@@ -225,10 +228,6 @@ class MobileViTV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC ...@@ -225,10 +228,6 @@ class MobileViTV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
def test_multi_gpu_data_parallel_forward(self): def test_multi_gpu_data_parallel_forward(self):
pass pass
@unittest.skip("Will be fixed soon by reducing the size of the model used for common tests.")
def test_model_is_small(self):
pass
def test_forward_signature(self): def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common() config, _ = self.model_tester.prepare_config_and_inputs_for_common()
......
...@@ -2708,7 +2708,7 @@ class ModelTesterMixin: ...@@ -2708,7 +2708,7 @@ class ModelTesterMixin:
def test_model_is_small(self): def test_model_is_small(self):
# Just a consistency check to make sure we are not running tests on 80M parameter models. # Just a consistency check to make sure we are not running tests on 80M parameter models.
config, _ = self.model_tester.prepare_config_and_inputs_for_common() config, _ = self.model_tester.prepare_config_and_inputs_for_common()
# print(config) print(config)
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
model = model_class(config) model = model_class(config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment