"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e3e16ddc3c22b9bc49ea19b616bc3eec58d6cc9c"
Unverified Commit 9932ee4b authored by Francesco Saverio Zuppichini's avatar Francesco Saverio Zuppichini Committed by GitHub
Browse files

made MaskFormerModelTest faster (#15942)

parent e8efaecb
......@@ -20,7 +20,7 @@ import unittest
import numpy as np
from tests.test_modeling_common import floats_tensor
from transformers import MaskFormerConfig, is_torch_available, is_vision_available
from transformers import DetrConfig, MaskFormerConfig, SwinConfig, is_torch_available, is_vision_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
......@@ -47,12 +47,12 @@ class MaskFormerModelTester:
batch_size=2,
is_training=True,
use_auxiliary_loss=False,
num_queries=100,
num_queries=10,
num_channels=3,
min_size=384,
max_size=640,
num_labels=150,
mask_feature_size=256,
min_size=32 * 4,
max_size=32 * 6,
num_labels=4,
mask_feature_size=32,
):
self.parent = parent
self.batch_size = batch_size
......@@ -79,11 +79,20 @@ class MaskFormerModelTester:
return config, pixel_values, pixel_mask, mask_labels, class_labels
def get_config(self):
return MaskFormerConfig(
num_queries=self.num_queries,
return MaskFormerConfig.from_backbone_and_decoder_configs(
backbone_config=SwinConfig(
depths=[1, 1, 1, 1],
),
decoder_config=DetrConfig(
decoder_ffn_dim=128,
num_queries=self.num_queries,
decoder_attention_heads=2,
d_model=self.mask_feature_size,
),
mask_feature_size=self.mask_feature_size,
fpn_feature_size=self.mask_feature_size,
num_channels=self.num_channels,
num_labels=self.num_labels,
mask_feature_size=self.mask_feature_size,
)
def prepare_config_and_inputs_for_common(self):
......@@ -161,7 +170,6 @@ class MaskFormerModelTester:
@require_torch
@slow
class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (MaskFormerModel, MaskFormerForInstanceSegmentation) if is_torch_available() else ()
......@@ -221,11 +229,11 @@ class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase):
model = MaskFormerModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@slow
def test_model_with_labels(self):
size = (self.model_tester.min_size,) * 2
inputs = {
"pixel_values": torch.randn((2, 3, 384, 384)),
"mask_labels": torch.randn((2, 10, 384, 384)),
"pixel_values": torch.randn((2, 3, *size)),
"mask_labels": torch.randn((2, 10, *size)),
"class_labels": torch.zeros(2, 10).long(),
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment