Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
323f28dc
Unverified
Commit
323f28dc
authored
Nov 01, 2021
by
Nicolas Patry
Committed by
GitHub
Nov 01, 2021
Browse files
Fixing `image-segmentation` tests. (#14223)
parent
7396095a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
6 deletions
+11
-6
src/transformers/pipelines/image_segmentation.py
src/transformers/pipelines/image_segmentation.py
+3
-3
tests/test_pipelines_image_segmentation.py
tests/test_pipelines_image_segmentation.py
+8
-3
No files found.
src/transformers/pipelines/image_segmentation.py
View file @
323f28dc
...
@@ -126,13 +126,13 @@ class ImageSegmentationPipeline(Pipeline):
...
@@ -126,13 +126,13 @@ class ImageSegmentationPipeline(Pipeline):
def
_forward
(
self
,
model_inputs
):
def
_forward
(
self
,
model_inputs
):
target_size
=
model_inputs
.
pop
(
"target_size"
)
target_size
=
model_inputs
.
pop
(
"target_size"
)
outputs
=
self
.
model
(
**
model_inputs
)
model_
outputs
=
self
.
model
(
**
model_inputs
)
model_outputs
=
{
"outputs"
:
outputs
,
"target_size"
:
target_size
}
model_outputs
[
"target_size"
]
=
target_size
return
model_outputs
return
model_outputs
def
postprocess
(
self
,
model_outputs
,
threshold
=
0.9
,
mask_threshold
=
0.5
):
def
postprocess
(
self
,
model_outputs
,
threshold
=
0.9
,
mask_threshold
=
0.5
):
raw_annotations
=
self
.
feature_extractor
.
post_process_segmentation
(
raw_annotations
=
self
.
feature_extractor
.
post_process_segmentation
(
model_outputs
[
"outputs"
]
,
model_outputs
[
"target_size"
],
threshold
=
threshold
,
mask_threshold
=
0.5
model_outputs
,
model_outputs
[
"target_size"
],
threshold
=
threshold
,
mask_threshold
=
0.5
)
)
raw_annotation
=
raw_annotations
[
0
]
raw_annotation
=
raw_annotations
[
0
]
...
...
tests/test_pipelines_image_segmentation.py
View file @
323f28dc
...
@@ -51,13 +51,18 @@ else:
...
@@ -51,13 +51,18 @@ else:
@
require_timm
@
require_timm
@
require_torch
@
require_torch
@
is_pipeline_test
@
is_pipeline_test
@
unittest
.
skip
(
"Skip while fixing segmentation pipeline tests"
)
class
ImageSegmentationPipelineTests
(
unittest
.
TestCase
,
metaclass
=
PipelineTestCaseMeta
):
class
ImageSegmentationPipelineTests
(
unittest
.
TestCase
,
metaclass
=
PipelineTestCaseMeta
):
model_mapping
=
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
model_mapping
=
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
@
require_datasets
def
get_test_pipeline
(
self
,
model
,
tokenizer
,
feature_extractor
):
def
run_pipeline_test
(
self
,
model
,
tokenizer
,
feature_extractor
):
image_segmenter
=
ImageSegmentationPipeline
(
model
=
model
,
feature_extractor
=
feature_extractor
)
image_segmenter
=
ImageSegmentationPipeline
(
model
=
model
,
feature_extractor
=
feature_extractor
)
return
image_segmenter
,
[
"./tests/fixtures/tests_samples/COCO/000000039769.png"
,
"./tests/fixtures/tests_samples/COCO/000000039769.png"
,
]
@
require_datasets
def
run_pipeline_test
(
self
,
image_segmenter
,
examples
):
outputs
=
image_segmenter
(
"./tests/fixtures/tests_samples/COCO/000000039769.png"
,
threshold
=
0.0
)
outputs
=
image_segmenter
(
"./tests/fixtures/tests_samples/COCO/000000039769.png"
,
threshold
=
0.0
)
self
.
assertEqual
(
outputs
,
[{
"score"
:
ANY
(
float
),
"label"
:
ANY
(
str
),
"mask"
:
ANY
(
str
)}]
*
12
)
self
.
assertEqual
(
outputs
,
[{
"score"
:
ANY
(
float
),
"label"
:
ANY
(
str
),
"mask"
:
ANY
(
str
)}]
*
12
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment