Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
cd245780
Unverified
Commit
cd245780
authored
Jan 03, 2023
by
Alara Dirik
Committed by
GitHub
Jan 03, 2023
Browse files
Improve OWL-ViT postprocessing (#20980)
* add post_process_object_detection method * style changes
parent
e901914d
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
87 additions
and
18 deletions
+87
-18
docs/source/en/model_doc/owlvit.mdx
docs/source/en/model_doc/owlvit.mdx
+1
-1
src/transformers/models/owlvit/image_processing_owlvit.py
src/transformers/models/owlvit/image_processing_owlvit.py
+63
-1
src/transformers/models/owlvit/modeling_owlvit.py
src/transformers/models/owlvit/modeling_owlvit.py
+11
-11
src/transformers/models/owlvit/processing_owlvit.py
src/transformers/models/owlvit/processing_owlvit.py
+7
-0
src/transformers/pipelines/zero_shot_object_detection.py
src/transformers/pipelines/zero_shot_object_detection.py
+3
-4
tests/pipelines/test_pipelines_zero_shot_object_detection.py
tests/pipelines/test_pipelines_zero_shot_object_detection.py
+2
-1
No files found.
docs/source/en/model_doc/owlvit.mdx
View file @
cd245780
...
@@ -80,7 +80,7 @@ This model was contributed by [adirik](https://huggingface.co/adirik). The origi
...
@@ -80,7 +80,7 @@ This model was contributed by [adirik](https://huggingface.co/adirik). The origi
[[autodoc]] OwlViTImageProcessor
[[autodoc]] OwlViTImageProcessor
- preprocess
- preprocess
- post_process
- post_process_object_detection
- post_process_image_guided_detection
- post_process_image_guided_detection
## OwlViTFeatureExtractor
## OwlViTFeatureExtractor
...
...
src/transformers/models/owlvit/image_processing_owlvit.py
View file @
cd245780
...
@@ -14,7 +14,8 @@
...
@@ -14,7 +14,8 @@
# limitations under the License.
# limitations under the License.
"""Image processor class for OwlViT"""
"""Image processor class for OwlViT"""
from
typing
import
Dict
,
List
,
Optional
,
Union
import
warnings
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -344,6 +345,12 @@ class OwlViTImageProcessor(BaseImageProcessor):
...
@@ -344,6 +345,12 @@ class OwlViTImageProcessor(BaseImageProcessor):
in the batch as predicted by the model.
in the batch as predicted by the model.
"""
"""
# TODO: (amy) add support for other frameworks
# TODO: (amy) add support for other frameworks
warnings
.
warn
(
"`post_process` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_object_detection`"
,
FutureWarning
,
)
logits
,
boxes
=
outputs
.
logits
,
outputs
.
pred_boxes
logits
,
boxes
=
outputs
.
logits
,
outputs
.
pred_boxes
if
len
(
logits
)
!=
len
(
target_sizes
):
if
len
(
logits
)
!=
len
(
target_sizes
):
...
@@ -367,6 +374,61 @@ class OwlViTImageProcessor(BaseImageProcessor):
...
@@ -367,6 +374,61 @@ class OwlViTImageProcessor(BaseImageProcessor):
return
results
return
results
def post_process_object_detection(
    self, outputs, threshold: float = 0.1, target_sizes: Union[TensorType, List[Tuple]] = None
):
    """
    Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
    bottom_right_x, bottom_right_y) format.

    Args:
        outputs ([`OwlViTObjectDetectionOutput`]):
            Raw outputs of the model. Only `outputs.logits` and `outputs.pred_boxes` are read.
        threshold (`float`, *optional*, defaults to 0.1):
            Score threshold to keep object detection predictions. Only predictions whose score is strictly
            greater than this value are kept.
        target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
            Tensor of shape `(batch_size, 2)` or list/tuple of tuples (`Tuple[int, int]`) containing the target
            size `(height, width)` of each image in the batch. If unset, predictions will not be resized.

    Returns:
        `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
        in the batch as predicted by the model.

    Raises:
        ValueError: If `target_sizes` is given and its length differs from the batch dimension of the logits.
    """
    # TODO: (amy) add support for other frameworks
    logits, boxes = outputs.logits, outputs.pred_boxes

    if target_sizes is not None and len(logits) != len(target_sizes):
        raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")

    # For every predicted box keep the best-scoring query: its sigmoid score and its index as the label
    probs = torch.max(logits, dim=-1)
    scores = torch.sigmoid(probs.values)
    labels = probs.indices

    # Convert to [x0, y0, x1, y1] format
    boxes = center_to_corners_format(boxes)

    # Convert from relative [0, 1] to absolute [0, height] coordinates
    if target_sizes is not None:
        # Accept any plain Python sequence of (height, width) pairs, not only lists
        if isinstance(target_sizes, (list, tuple)):
            img_h = torch.Tensor([size[0] for size in target_sizes])
            img_w = torch.Tensor([size[1] for size in target_sizes])
        else:
            img_h, img_w = target_sizes.unbind(1)

        # The scale factors may live on a different device than the model outputs (e.g. sizes built on CPU
        # while the boxes are on an accelerator); move them so the multiplication below does not fail.
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
        boxes = boxes * scale_fct[:, None, :]

    results = []
    for s, l, b in zip(scores, labels, boxes):
        # Compute the keep-mask once and reuse it for scores, labels and boxes
        keep = s > threshold
        results.append({"scores": s[keep], "labels": l[keep], "boxes": b[keep]})

    return results
# TODO: (Amy) Make compatible with other frameworks
# TODO: (Amy) Make compatible with other frameworks
def
post_process_image_guided_detection
(
self
,
outputs
,
threshold
=
0.6
,
nms_threshold
=
0.3
,
target_sizes
=
None
):
def
post_process_image_guided_detection
(
self
,
outputs
,
threshold
=
0.6
,
nms_threshold
=
0.3
,
target_sizes
=
None
):
"""
"""
...
...
src/transformers/models/owlvit/modeling_owlvit.py
View file @
cd245780
...
@@ -204,8 +204,8 @@ class OwlViTObjectDetectionOutput(ModelOutput):
...
@@ -204,8 +204,8 @@ class OwlViTObjectDetectionOutput(ModelOutput):
pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the
unnormalized
possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to retrieve the
bounding boxes.
unnormalized
bounding boxes.
text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim`):
text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim`):
The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`].
The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`].
image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
...
@@ -248,13 +248,13 @@ class OwlViTImageGuidedObjectDetectionOutput(ModelOutput):
...
@@ -248,13 +248,13 @@ class OwlViTImageGuidedObjectDetectionOutput(ModelOutput):
target_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
target_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
values are normalized in [0, 1], relative to the size of each individual target image in the batch
values are normalized in [0, 1], relative to the size of each individual target image in the batch
(disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the
(disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to
unnormalized bounding boxes.
retrieve the
unnormalized bounding boxes.
query_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
query_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
values are normalized in [0, 1], relative to the size of each individual query image in the batch
values are normalized in [0, 1], relative to the size of each individual query image in the batch
(disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the
(disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to
unnormalized bounding boxes.
retrieve the
unnormalized bounding boxes.
image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes
Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes
image embeddings for each patch.
image embeddings for each patch.
...
@@ -1644,18 +1644,18 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
...
@@ -1644,18 +1644,18 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
>>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
>>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
>>> target_sizes = torch.Tensor([image.size[::-1]])
>>> target_sizes = torch.Tensor([image.size[::-1]])
>>> # Convert outputs (bounding boxes and class logits) to COCO API
>>> # Convert outputs (bounding boxes and class logits) to final bounding boxes and scores
>>> results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
>>> results = processor.post_process_object_detection(
... outputs=outputs, threshold=0.1, target_sizes=target_sizes
... )
>>> i = 0 # Retrieve predictions for the first image for the corresponding text queries
>>> i = 0 # Retrieve predictions for the first image for the corresponding text queries
>>> text = texts[i]
>>> text = texts[i]
>>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
>>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
>>> score_threshold = 0.1
>>> for box, score, label in zip(boxes, scores, labels):
>>> for box, score, label in zip(boxes, scores, labels):
... box = [round(i, 2) for i in box.tolist()]
... box = [round(i, 2) for i in box.tolist()]
... if score >= score_threshold:
... print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
... print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
```"""
```"""
...
...
src/transformers/models/owlvit/processing_owlvit.py
View file @
cd245780
...
@@ -179,6 +179,13 @@ class OwlViTProcessor(ProcessorMixin):
...
@@ -179,6 +179,13 @@ class OwlViTProcessor(ProcessorMixin):
"""
"""
return
self
.
image_processor
.
post_process
(
*
args
,
**
kwargs
)
return
self
.
image_processor
.
post_process
(
*
args
,
**
kwargs
)
def post_process_object_detection(self, *args, **kwargs):
    """
    Thin pass-through to [`OwlViTImageProcessor.post_process_object_detection`]: every positional and keyword
    argument is handed to the image processor unchanged and its result is returned as-is. Please refer to the
    docstring of that method for the accepted arguments.
    """
    image_processor = self.image_processor
    return image_processor.post_process_object_detection(*args, **kwargs)
def
post_process_image_guided_detection
(
self
,
*
args
,
**
kwargs
):
def
post_process_image_guided_detection
(
self
,
*
args
,
**
kwargs
):
"""
"""
This method forwards all its arguments to [`OwlViTImageProcessor.post_process_one_shot_object_detection`].
This method forwards all its arguments to [`OwlViTImageProcessor.post_process_one_shot_object_detection`].
...
...
src/transformers/pipelines/zero_shot_object_detection.py
View file @
cd245780
...
@@ -173,12 +173,11 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline):
...
@@ -173,12 +173,11 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline):
for
model_output
in
model_outputs
:
for
model_output
in
model_outputs
:
label
=
model_output
[
"candidate_label"
]
label
=
model_output
[
"candidate_label"
]
model_output
=
BaseModelOutput
(
model_output
)
model_output
=
BaseModelOutput
(
model_output
)
outputs
=
self
.
feature_extractor
.
post_process
(
outputs
=
self
.
feature_extractor
.
post_process
_object_detection
(
outputs
=
model_output
,
target_sizes
=
model_output
[
"target_size"
]
outputs
=
model_output
,
threshold
=
threshold
,
target_sizes
=
model_output
[
"target_size"
]
)[
0
]
)[
0
]
keep
=
outputs
[
"scores"
]
>=
threshold
for
index
in
keep
.
nonzero
():
for
index
in
outputs
[
"scores"
]
.
nonzero
():
score
=
outputs
[
"scores"
][
index
].
item
()
score
=
outputs
[
"scores"
][
index
].
item
()
box
=
self
.
_get_bounding_box
(
outputs
[
"boxes"
][
index
][
0
])
box
=
self
.
_get_bounding_box
(
outputs
[
"boxes"
][
index
][
0
])
...
...
tests/pipelines/test_pipelines_zero_shot_object_detection.py
View file @
cd245780
...
@@ -131,7 +131,8 @@ class ZeroShotObjectDetectionPipelineTests(unittest.TestCase, metaclass=Pipeline
...
@@ -131,7 +131,8 @@ class ZeroShotObjectDetectionPipelineTests(unittest.TestCase, metaclass=Pipeline
object_detector
=
pipeline
(
"zero-shot-object-detection"
)
object_detector
=
pipeline
(
"zero-shot-object-detection"
)
outputs
=
object_detector
(
outputs
=
object_detector
(
"http://images.cocodataset.org/val2017/000000039769.jpg"
,
candidate_labels
=
[
"cat"
,
"remote"
,
"couch"
]
"http://images.cocodataset.org/val2017/000000039769.jpg"
,
candidate_labels
=
[
"cat"
,
"remote"
,
"couch"
],
)
)
self
.
assertEqual
(
self
.
assertEqual
(
nested_simplify
(
outputs
,
decimals
=
4
),
nested_simplify
(
outputs
,
decimals
=
4
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment