chenpangpang / transformers / Commits

Unverified commit cd245780, authored Jan 03, 2023 by Alara Dirik, committed by GitHub on Jan 03, 2023
Improve OWL-ViT postprocessing (#20980)
* add post_process_object_detection method
* style changes
parent e901914d
Changes: 6 changed files with 87 additions and 18 deletions (+87 -18)
docs/source/en/model_doc/owlvit.mdx                            +1  -1
src/transformers/models/owlvit/image_processing_owlvit.py      +63 -1
src/transformers/models/owlvit/modeling_owlvit.py              +11 -11
src/transformers/models/owlvit/processing_owlvit.py            +7  -0
src/transformers/pipelines/zero_shot_object_detection.py       +3  -4
tests/pipelines/test_pipelines_zero_shot_object_detection.py   +2  -1
docs/source/en/model_doc/owlvit.mdx

@@ -80,7 +80,7 @@ This model was contributed by [adirik](https://huggingface.co/adirik). The origi
 [[autodoc]] OwlViTImageProcessor
     - preprocess
-    - post_process
+    - post_process_object_detection
     - post_process_image_guided_detection

 ## OwlViTFeatureExtractor
src/transformers/models/owlvit/image_processing_owlvit.py

@@ -14,7 +14,8 @@
 # limitations under the License.
 """Image processor class for OwlViT"""

-from typing import Dict, List, Optional, Union
+import warnings
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np

@@ -344,6 +345,12 @@ class OwlViTImageProcessor(BaseImageProcessor):
             in the batch as predicted by the model.
         """
         # TODO: (amy) add support for other frameworks
+        warnings.warn(
+            "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
+            " `post_process_object_detection`",
+            FutureWarning,
+        )
+
         logits, boxes = outputs.logits, outputs.pred_boxes

         if len(logits) != len(target_sizes):

@@ -367,6 +374,61 @@ class OwlViTImageProcessor(BaseImageProcessor):

         return results

+    def post_process_object_detection(
+        self, outputs, threshold: float = 0.1, target_sizes: Union[TensorType, List[Tuple]] = None
+    ):
+        """
+        Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
+        bottom_right_x, bottom_right_y) format.
+
+        Args:
+            outputs ([`OwlViTObjectDetectionOutput`]):
+                Raw outputs of the model.
+            threshold (`float`, *optional*):
+                Score threshold to keep object detection predictions.
+            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
+                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
+                `(height, width)` of each image in the batch. If unset, predictions will not be resized.
+
+        Returns:
+            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+            in the batch as predicted by the model.
+        """
+        # TODO: (amy) add support for other frameworks
+        logits, boxes = outputs.logits, outputs.pred_boxes
+
+        if target_sizes is not None:
+            if len(logits) != len(target_sizes):
+                raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
+
+        probs = torch.max(logits, dim=-1)
+        scores = torch.sigmoid(probs.values)
+        labels = probs.indices
+
+        # Convert to [x0, y0, x1, y1] format
+        boxes = center_to_corners_format(boxes)
+
+        # Convert from relative [0, 1] to absolute [0, height] coordinates
+        if target_sizes is not None:
+            if isinstance(target_sizes, List):
+                img_h = torch.Tensor([i[0] for i in target_sizes])
+                img_w = torch.Tensor([i[1] for i in target_sizes])
+            else:
+                img_h, img_w = target_sizes.unbind(1)
+
+            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
+            boxes = boxes * scale_fct[:, None, :]
+
+        results = []
+        for s, l, b in zip(scores, labels, boxes):
+            score = s[s > threshold]
+            label = l[s > threshold]
+            box = b[s > threshold]
+            results.append({"scores": score, "labels": label, "boxes": box})
+
+        return results
+
     # TODO: (Amy) Make compatible with other frameworks
     def post_process_image_guided_detection(self, outputs, threshold=0.6, nms_threshold=0.3, target_sizes=None):
         """
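For context, a minimal end-to-end sketch of the new API added above. The checkpoint name is an assumption (it does not appear in this diff); the image URL, text query and threshold mirror the doc example and tests elsewhere in this commit:

# Sketch only: exercises OwlViTImageProcessor.post_process_object_detection via the processor.
import requests
import torch
from PIL import Image
from transformers import OwlViTForObjectDetection, OwlViTProcessor

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")  # assumed checkpoint
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=[["a photo of a cat"]], images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Target size is (height, width); PIL's image.size is (width, height), hence the reversal.
target_sizes = torch.Tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs=outputs, threshold=0.1, target_sizes=target_sizes)
print(results[0]["boxes"], results[0]["scores"], results[0]["labels"])

Note that the old `post_process` path now emits a FutureWarning pointing callers at this method.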
src/transformers/models/owlvit/modeling_owlvit.py

@@ -204,8 +204,8 @@ class OwlViTObjectDetectionOutput(ModelOutput):
         pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
             Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
             values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
-            possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the unnormalized
-            bounding boxes.
+            possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to retrieve the
+            unnormalized bounding boxes.
         text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim`):
             The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`].
         image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):

@@ -248,13 +248,13 @@ class OwlViTImageGuidedObjectDetectionOutput(ModelOutput):
         target_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
             Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
             values are normalized in [0, 1], relative to the size of each individual target image in the batch
-            (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the
-            unnormalized bounding boxes.
+            (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to
+            retrieve the unnormalized bounding boxes.
         query_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
             Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
             values are normalized in [0, 1], relative to the size of each individual query image in the batch
-            (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the
-            unnormalized bounding boxes.
+            (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to
+            retrieve the unnormalized bounding boxes.
         image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
             Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes
             image embeddings for each patch.

@@ -1644,17 +1644,17 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
         >>> target_sizes = torch.Tensor([image.size[::-1]])
-        >>> # Convert outputs (bounding boxes and class logits) to COCO API
-        >>> results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
+        >>> # Convert outputs (bounding boxes and class logits) to final bounding boxes and scores
+        >>> results = processor.post_process_object_detection(
+        ...     outputs=outputs, threshold=0.1, target_sizes=target_sizes
+        ... )

         >>> i = 0  # Retrieve predictions for the first image for the corresponding text queries
         >>> text = texts[i]
         >>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]

-        >>> score_threshold = 0.1
         >>> for box, score, label in zip(boxes, scores, labels):
         ...     box = [round(i, 2) for i in box.tolist()]
-        ...     if score >= score_threshold:
-        ...         print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
+        ...     print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
         Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
         Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
src/transformers/models/owlvit/processing_owlvit.py

@@ -179,6 +179,13 @@ class OwlViTProcessor(ProcessorMixin):
         """
         return self.image_processor.post_process(*args, **kwargs)

+    def post_process_object_detection(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to [`OwlViTImageProcessor.post_process_object_detection`]. Please refer
+        to the docstring of this method for more information.
+        """
+        return self.image_processor.post_process_object_detection(*args, **kwargs)
+
     def post_process_image_guided_detection(self, *args, **kwargs):
         """
         This method forwards all its arguments to [`OwlViTImageProcessor.post_process_one_shot_object_detection`].
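Continuing the sketch above: the new processor method is a thin pass-through, so the two calls below are interchangeable (illustrative only; `processor`, `outputs` and `target_sizes` come from the earlier sketch):

# The processor simply forwards its arguments to the image processor, as the hunk above shows.
results_a = processor.post_process_object_detection(outputs=outputs, threshold=0.1, target_sizes=target_sizes)
results_b = processor.image_processor.post_process_object_detection(outputs=outputs, threshold=0.1, target_sizes=target_sizes)
# results_a and results_b are produced by the same underlying method.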
src/transformers/pipelines/zero_shot_object_detection.py

@@ -173,12 +173,11 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline):
         for model_output in model_outputs:
             label = model_output["candidate_label"]
             model_output = BaseModelOutput(model_output)
-            outputs = self.feature_extractor.post_process(
-                outputs=model_output, target_sizes=model_output["target_size"]
+            outputs = self.feature_extractor.post_process_object_detection(
+                outputs=model_output, threshold=threshold, target_sizes=model_output["target_size"]
             )[0]
-            keep = outputs["scores"] >= threshold

-            for index in keep.nonzero():
+            for index in outputs["scores"].nonzero():
                 score = outputs["scores"][index].item()
                 box = self._get_bounding_box(outputs["boxes"][index][0])
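At the pipeline level, the user-facing effect of this change is that `threshold` is now applied inside post-processing rather than by a separate `keep` mask. A hedged usage sketch (the default checkpoint and output format follow the test below; exact scores will vary):

from transformers import pipeline

detector = pipeline("zero-shot-object-detection")
predictions = detector(
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    candidate_labels=["cat", "remote", "couch"],
    threshold=0.1,  # forwarded into post_process_object_detection by the hunk above
)
# Each entry is a dict with "score", "label" and "box" keys.
for pred in predictions:
    print(pred)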
tests/pipelines/test_pipelines_zero_shot_object_detection.py

@@ -131,7 +131,8 @@ class ZeroShotObjectDetectionPipelineTests(unittest.TestCase, metaclass=Pipeline
         object_detector = pipeline("zero-shot-object-detection")

         outputs = object_detector(
-            "http://images.cocodataset.org/val2017/000000039769.jpg", candidate_labels=["cat", "remote", "couch"]
+            "http://images.cocodataset.org/val2017/000000039769.jpg",
+            candidate_labels=["cat", "remote", "couch"],
         )

         self.assertEqual(
             nested_simplify(outputs, decimals=4),