Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
52c9e6af
Unverified
Commit
52c9e6af
authored
Jan 04, 2023
by
Alara Dirik
Committed by
GitHub
Jan 04, 2023
Browse files
Fix bug in segmentation postprocessing (#20198)
* Fix post_process_instance_segmentation * Add test for label fusing
parent
292acd71
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
13 deletions
+41
-13
src/transformers/models/maskformer/image_processing_maskformer.py
...sformers/models/maskformer/image_processing_maskformer.py
+14
-13
tests/models/maskformer/test_feature_extraction_maskformer.py
...s/models/maskformer/test_feature_extraction_maskformer.py
+27
-0
No files found.
src/transformers/models/maskformer/image_processing_maskformer.py
View file @
52c9e6af
...
@@ -1050,12 +1050,13 @@ class MaskFormerImageProcessor(BaseImageProcessor):
...
@@ -1050,12 +1050,13 @@ class MaskFormerImageProcessor(BaseImageProcessor):
# Get segmentation map and segment information of batch item
# Get segmentation map and segment information of batch item
target_size
=
target_sizes
[
i
]
if
target_sizes
is
not
None
else
None
target_size
=
target_sizes
[
i
]
if
target_sizes
is
not
None
else
None
segmentation
,
segments
=
compute_segments
(
segmentation
,
segments
=
compute_segments
(
mask_probs_item
,
mask_probs
=
mask_probs_item
,
pred_scores_item
,
pred_scores
=
pred_scores_item
,
pred_labels_item
,
pred_labels
=
pred_labels_item
,
mask_threshold
,
mask_threshold
=
mask_threshold
,
overlap_mask_area_threshold
,
overlap_mask_area_threshold
=
overlap_mask_area_threshold
,
target_size
,
label_ids_to_fuse
=
[],
target_size
=
target_size
,
)
)
# Return segmentation map in run-length encoding (RLE) format
# Return segmentation map in run-length encoding (RLE) format
...
@@ -1143,13 +1144,13 @@ class MaskFormerImageProcessor(BaseImageProcessor):
...
@@ -1143,13 +1144,13 @@ class MaskFormerImageProcessor(BaseImageProcessor):
# Get segmentation map and segment information of batch item
# Get segmentation map and segment information of batch item
target_size
=
target_sizes
[
i
]
if
target_sizes
is
not
None
else
None
target_size
=
target_sizes
[
i
]
if
target_sizes
is
not
None
else
None
segmentation
,
segments
=
compute_segments
(
segmentation
,
segments
=
compute_segments
(
mask_probs_item
,
mask_probs
=
mask_probs_item
,
pred_scores_item
,
pred_scores
=
pred_scores_item
,
pred_labels_item
,
pred_labels
=
pred_labels_item
,
mask_threshold
,
mask_threshold
=
mask_threshold
,
overlap_mask_area_threshold
,
overlap_mask_area_threshold
=
overlap_mask_area_threshold
,
label_ids_to_fuse
,
label_ids_to_fuse
=
label_ids_to_fuse
,
target_size
,
target_size
=
target_size
,
)
)
results
.
append
({
"segmentation"
:
segmentation
,
"segments_info"
:
segments
})
results
.
append
({
"segmentation"
:
segmentation
,
"segments_info"
:
segments
})
...
...
tests/models/maskformer/test_feature_extraction_maskformer.py
View file @
52c9e6af
...
@@ -589,3 +589,30 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
...
@@ -589,3 +589,30 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
self
.
assertEqual
(
self
.
assertEqual
(
el
[
"segmentation"
].
shape
,
(
self
.
feature_extract_tester
.
height
,
self
.
feature_extract_tester
.
width
)
el
[
"segmentation"
].
shape
,
(
self
.
feature_extract_tester
.
height
,
self
.
feature_extract_tester
.
width
)
)
)
def
test_post_process_label_fusing
(
self
):
feature_extractor
=
self
.
feature_extraction_class
(
num_labels
=
self
.
feature_extract_tester
.
num_classes
)
outputs
=
self
.
feature_extract_tester
.
get_fake_maskformer_outputs
()
segmentation
=
feature_extractor
.
post_process_panoptic_segmentation
(
outputs
,
threshold
=
0
,
mask_threshold
=
0
,
overlap_mask_area_threshold
=
0
)
unfused_segments
=
[
el
[
"segments_info"
]
for
el
in
segmentation
]
fused_segmentation
=
feature_extractor
.
post_process_panoptic_segmentation
(
outputs
,
threshold
=
0
,
mask_threshold
=
0
,
overlap_mask_area_threshold
=
0
,
label_ids_to_fuse
=
{
1
}
)
fused_segments
=
[
el
[
"segments_info"
]
for
el
in
fused_segmentation
]
for
el_unfused
,
el_fused
in
zip
(
unfused_segments
,
fused_segments
):
if
len
(
el_unfused
)
==
0
:
self
.
assertEqual
(
len
(
el_unfused
),
len
(
el_fused
))
continue
# Get number of segments to be fused
fuse_targets
=
[
1
for
el
in
el_unfused
if
el
[
"label_id"
]
in
{
1
}]
num_to_fuse
=
0
if
len
(
fuse_targets
)
==
0
else
sum
(
fuse_targets
)
-
1
# Expected number of segments after fusing
expected_num_segments
=
max
([
el
[
"id"
]
for
el
in
el_unfused
])
-
num_to_fuse
num_segments_fused
=
max
([
el
[
"id"
]
for
el
in
el_fused
])
self
.
assertEqual
(
num_segments_fused
,
expected_num_segments
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment