Unverified Commit 9932ee4b authored by Francesco Saverio Zuppichini's avatar Francesco Saverio Zuppichini Committed by GitHub
Browse files

made MaskFormerModelTest faster (#15942)

parent e8efaecb
...@@ -20,7 +20,7 @@ import unittest ...@@ -20,7 +20,7 @@ import unittest
import numpy as np import numpy as np
from tests.test_modeling_common import floats_tensor from tests.test_modeling_common import floats_tensor
from transformers import MaskFormerConfig, is_torch_available, is_vision_available from transformers import DetrConfig, MaskFormerConfig, SwinConfig, is_torch_available, is_vision_available
from transformers.file_utils import cached_property from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, require_vision, slow, torch_device from transformers.testing_utils import require_torch, require_vision, slow, torch_device
...@@ -47,12 +47,12 @@ class MaskFormerModelTester: ...@@ -47,12 +47,12 @@ class MaskFormerModelTester:
batch_size=2, batch_size=2,
is_training=True, is_training=True,
use_auxiliary_loss=False, use_auxiliary_loss=False,
num_queries=100, num_queries=10,
num_channels=3, num_channels=3,
min_size=384, min_size=32 * 4,
max_size=640, max_size=32 * 6,
num_labels=150, num_labels=4,
mask_feature_size=256, mask_feature_size=32,
): ):
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
...@@ -79,11 +79,20 @@ class MaskFormerModelTester: ...@@ -79,11 +79,20 @@ class MaskFormerModelTester:
return config, pixel_values, pixel_mask, mask_labels, class_labels return config, pixel_values, pixel_mask, mask_labels, class_labels
def get_config(self): def get_config(self):
return MaskFormerConfig( return MaskFormerConfig.from_backbone_and_decoder_configs(
num_queries=self.num_queries, backbone_config=SwinConfig(
depths=[1, 1, 1, 1],
),
decoder_config=DetrConfig(
decoder_ffn_dim=128,
num_queries=self.num_queries,
decoder_attention_heads=2,
d_model=self.mask_feature_size,
),
mask_feature_size=self.mask_feature_size,
fpn_feature_size=self.mask_feature_size,
num_channels=self.num_channels, num_channels=self.num_channels,
num_labels=self.num_labels, num_labels=self.num_labels,
mask_feature_size=self.mask_feature_size,
) )
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
...@@ -161,7 +170,6 @@ class MaskFormerModelTester: ...@@ -161,7 +170,6 @@ class MaskFormerModelTester:
@require_torch @require_torch
@slow
class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase): class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (MaskFormerModel, MaskFormerForInstanceSegmentation) if is_torch_available() else () all_model_classes = (MaskFormerModel, MaskFormerForInstanceSegmentation) if is_torch_available() else ()
...@@ -221,11 +229,11 @@ class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -221,11 +229,11 @@ class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase):
model = MaskFormerModel.from_pretrained(model_name) model = MaskFormerModel.from_pretrained(model_name)
self.assertIsNotNone(model) self.assertIsNotNone(model)
@slow
def test_model_with_labels(self): def test_model_with_labels(self):
size = (self.model_tester.min_size,) * 2
inputs = { inputs = {
"pixel_values": torch.randn((2, 3, 384, 384)), "pixel_values": torch.randn((2, 3, *size)),
"mask_labels": torch.randn((2, 10, 384, 384)), "mask_labels": torch.randn((2, 10, *size)),
"class_labels": torch.zeros(2, 10).long(), "class_labels": torch.zeros(2, 10).long(),
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment