Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9932ee4b
"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e3e16ddc3c22b9bc49ea19b616bc3eec58d6cc9c"
Unverified
Commit
9932ee4b
authored
Mar 04, 2022
by
Francesco Saverio Zuppichini
Committed by
GitHub
Mar 04, 2022
Browse files
made MaskFormerModelTest faster (#15942)
parent
e8efaecb
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
13 deletions
+21
-13
tests/maskformer/test_modeling_maskformer.py
tests/maskformer/test_modeling_maskformer.py
+21
-13
No files found.
tests/maskformer/test_modeling_maskformer.py
View file @
9932ee4b
...
...
@@ -20,7 +20,7 @@ import unittest
import
numpy
as
np
from
tests.test_modeling_common
import
floats_tensor
from
transformers
import
MaskFormer
Config
,
is_torch_available
,
is_vision_available
from
transformers
import
DetrConfig
,
MaskFormerConfig
,
Swin
Config
,
is_torch_available
,
is_vision_available
from
transformers.file_utils
import
cached_property
from
transformers.testing_utils
import
require_torch
,
require_vision
,
slow
,
torch_device
...
...
@@ -47,12 +47,12 @@ class MaskFormerModelTester:
batch_size
=
2
,
is_training
=
True
,
use_auxiliary_loss
=
False
,
num_queries
=
10
0
,
num_queries
=
10
,
num_channels
=
3
,
min_size
=
3
8
4
,
max_size
=
640
,
num_labels
=
150
,
mask_feature_size
=
2
56
,
min_size
=
3
2
*
4
,
max_size
=
32
*
6
,
num_labels
=
4
,
mask_feature_size
=
3
2
,
):
self
.
parent
=
parent
self
.
batch_size
=
batch_size
...
...
@@ -79,11 +79,20 @@ class MaskFormerModelTester:
return
config
,
pixel_values
,
pixel_mask
,
mask_labels
,
class_labels
def
get_config
(
self
):
return
MaskFormerConfig
(
num_queries
=
self
.
num_queries
,
return
MaskFormerConfig
.
from_backbone_and_decoder_configs
(
backbone_config
=
SwinConfig
(
depths
=
[
1
,
1
,
1
,
1
],
),
decoder_config
=
DetrConfig
(
decoder_ffn_dim
=
128
,
num_queries
=
self
.
num_queries
,
decoder_attention_heads
=
2
,
d_model
=
self
.
mask_feature_size
,
),
mask_feature_size
=
self
.
mask_feature_size
,
fpn_feature_size
=
self
.
mask_feature_size
,
num_channels
=
self
.
num_channels
,
num_labels
=
self
.
num_labels
,
mask_feature_size
=
self
.
mask_feature_size
,
)
def
prepare_config_and_inputs_for_common
(
self
):
...
...
@@ -161,7 +170,6 @@ class MaskFormerModelTester:
@
require_torch
@
slow
class
MaskFormerModelTest
(
ModelTesterMixin
,
unittest
.
TestCase
):
all_model_classes
=
(
MaskFormerModel
,
MaskFormerForInstanceSegmentation
)
if
is_torch_available
()
else
()
...
...
@@ -221,11 +229,11 @@ class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase):
model
=
MaskFormerModel
.
from_pretrained
(
model_name
)
self
.
assertIsNotNone
(
model
)
@
slow
def
test_model_with_labels
(
self
):
size
=
(
self
.
model_tester
.
min_size
,)
*
2
inputs
=
{
"pixel_values"
:
torch
.
randn
((
2
,
3
,
384
,
384
)),
"mask_labels"
:
torch
.
randn
((
2
,
10
,
384
,
384
)),
"pixel_values"
:
torch
.
randn
((
2
,
3
,
*
size
)),
"mask_labels"
:
torch
.
randn
((
2
,
10
,
*
size
)),
"class_labels"
:
torch
.
zeros
(
2
,
10
).
long
(),
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment