Unverified Commit cd245780 authored by Alara Dirik, committed by GitHub

Improve OWL-ViT postprocessing (#20980)

* add post_process_object_detection method

* style changes
parent e901914d
@@ -80,7 +80,7 @@ This model was contributed by [adirik](https://huggingface.co/adirik). The origi
[[autodoc]] OwlViTImageProcessor
    - preprocess
-    - post_process
+    - post_process_object_detection
    - post_process_image_guided_detection

## OwlViTFeatureExtractor
...
@@ -14,7 +14,8 @@
# limitations under the License.
"""Image processor class for OwlViT"""

-from typing import Dict, List, Optional, Union
+import warnings
+from typing import Dict, List, Optional, Tuple, Union

import numpy as np
@@ -344,6 +345,12 @@ class OwlViTImageProcessor(BaseImageProcessor):
            in the batch as predicted by the model.
        """
        # TODO: (amy) add support for other frameworks
+        warnings.warn(
+            "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
+            " `post_process_object_detection`",
+            FutureWarning,
+        )
        logits, boxes = outputs.logits, outputs.pred_boxes

        if len(logits) != len(target_sizes):
@@ -367,6 +374,61 @@ class OwlViTImageProcessor(BaseImageProcessor):
        return results

+    def post_process_object_detection(
+        self, outputs, threshold: float = 0.1, target_sizes: Union[TensorType, List[Tuple]] = None
+    ):
+        """
+        Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
+        bottom_right_x, bottom_right_y) format.
+
+        Args:
+            outputs ([`OwlViTObjectDetectionOutput`]):
+                Raw outputs of the model.
+            threshold (`float`, *optional*):
+                Score threshold to keep object detection predictions.
+            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
+                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
+                `(height, width)` of each image in the batch. If unset, predictions will not be resized.
+        Returns:
+            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+            in the batch as predicted by the model.
+        """
+        # TODO: (amy) add support for other frameworks
+        logits, boxes = outputs.logits, outputs.pred_boxes
+
+        if target_sizes is not None:
+            if len(logits) != len(target_sizes):
+                raise ValueError(
+                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+                )
+
+        probs = torch.max(logits, dim=-1)
+        scores = torch.sigmoid(probs.values)
+        labels = probs.indices
+
+        # Convert to [x0, y0, x1, y1] format
+        boxes = center_to_corners_format(boxes)
+
+        # Convert from relative [0, 1] to absolute [0, height] coordinates
+        if target_sizes is not None:
+            if isinstance(target_sizes, List):
+                img_h = torch.Tensor([i[0] for i in target_sizes])
+                img_w = torch.Tensor([i[1] for i in target_sizes])
+            else:
+                img_h, img_w = target_sizes.unbind(1)
+
+            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
+            boxes = boxes * scale_fct[:, None, :]
+
+        results = []
+        for s, l, b in zip(scores, labels, boxes):
+            score = s[s > threshold]
+            label = l[s > threshold]
+            box = b[s > threshold]
+            results.append({"scores": score, "labels": label, "boxes": box})
+
+        return results
+
    # TODO: (Amy) Make compatible with other frameworks
    def post_process_image_guided_detection(self, outputs, threshold=0.6, nms_threshold=0.3, target_sizes=None):
        """
...
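Since `post_process` now raises a deprecation warning, below is a minimal migration sketch for the new image-processor-level API (the checkpoint name, image URL and text queries are illustrative assumptions, not part of this diff):

```python
import requests
import torch
from PIL import Image

from transformers import OwlViTForObjectDetection, OwlViTProcessor

# Illustrative checkpoint -- any OWL-ViT checkpoint behaves the same way.
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
texts = [["a photo of a cat", "a remote control"]]

inputs = processor(text=texts, images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# (height, width) of each image in the batch, used to rescale the normalized boxes
target_sizes = torch.Tensor([image.size[::-1]])

# Deprecated path: processor.post_process(outputs=outputs, target_sizes=target_sizes)
# returns unthresholded predictions and now emits a FutureWarning.
results = processor.post_process_object_detection(
    outputs=outputs, threshold=0.1, target_sizes=target_sizes
)

for score, label, box in zip(results[0]["scores"], results[0]["labels"], results[0]["boxes"]):
    box = [round(coord, 2) for coord in box.tolist()]
    print(f"Detected {texts[0][label]} with confidence {round(score.item(), 3)} at {box}")
```

The returned keys (`scores`, `labels`, `boxes`) match the deprecated method, so downstream code only needs to switch the call and drop its manual score filtering.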
@@ -204,8 +204,8 @@ class OwlViTObjectDetectionOutput(ModelOutput):
        pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
            Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
            values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
-            possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the unnormalized
-            bounding boxes.
+            possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to retrieve the
+            unnormalized bounding boxes.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim`):
            The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`].
        image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
@@ -248,13 +248,13 @@ class OwlViTImageGuidedObjectDetectionOutput(ModelOutput):
        target_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
            Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
            values are normalized in [0, 1], relative to the size of each individual target image in the batch
-            (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the
-            unnormalized bounding boxes.
+            (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to
+            retrieve the unnormalized bounding boxes.
        query_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
            Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
            values are normalized in [0, 1], relative to the size of each individual query image in the batch
-            (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the
-            unnormalized bounding boxes.
+            (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to
+            retrieve the unnormalized bounding boxes.
        image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
            Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes
            image embeddings for each patch.
@@ -1644,18 +1644,18 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
        >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
        >>> target_sizes = torch.Tensor([image.size[::-1]])
-        >>> # Convert outputs (bounding boxes and class logits) to COCO API
-        >>> results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
+        >>> # Convert outputs (bounding boxes and class logits) to final bounding boxes and scores
+        >>> results = processor.post_process_object_detection(
+        ...     outputs=outputs, threshold=0.1, target_sizes=target_sizes
+        ... )
        >>> i = 0  # Retrieve predictions for the first image for the corresponding text queries
        >>> text = texts[i]
        >>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]

-        >>> score_threshold = 0.1
        >>> for box, score, label in zip(boxes, scores, labels):
        ...     box = [round(i, 2) for i in box.tolist()]
-        ...     if score >= score_threshold:
-        ...         print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
+        ...     print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
        Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
        Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
        ```"""
...
@@ -179,6 +179,13 @@ class OwlViTProcessor(ProcessorMixin):
        """
        return self.image_processor.post_process(*args, **kwargs)

+    def post_process_object_detection(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to [`OwlViTImageProcessor.post_process_object_detection`]. Please refer
+        to the docstring of this method for more information.
+        """
+        return self.image_processor.post_process_object_detection(*args, **kwargs)
+
    def post_process_image_guided_detection(self, *args, **kwargs):
        """
        This method forwards all its arguments to [`OwlViTImageProcessor.post_process_one_shot_object_detection`].
...
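For clarity, the new processor method is a thin pass-through. Assuming a `processor`, `outputs` and `target_sizes` as in the sketch above, the following two calls are equivalent:

```python
# Both forms return a list with one dict per image: {"scores", "labels", "boxes"}.
results_via_processor = processor.post_process_object_detection(
    outputs=outputs, threshold=0.1, target_sizes=target_sizes
)
results_via_image_processor = processor.image_processor.post_process_object_detection(
    outputs=outputs, threshold=0.1, target_sizes=target_sizes
)
```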
@@ -173,12 +173,11 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline):
        for model_output in model_outputs:
            label = model_output["candidate_label"]
            model_output = BaseModelOutput(model_output)
-            outputs = self.feature_extractor.post_process(
-                outputs=model_output, target_sizes=model_output["target_size"]
+            outputs = self.feature_extractor.post_process_object_detection(
+                outputs=model_output, threshold=threshold, target_sizes=model_output["target_size"]
            )[0]
-            keep = outputs["scores"] >= threshold
-            for index in keep.nonzero():
+            for index in outputs["scores"].nonzero():
                score = outputs["scores"][index].item()
                box = self._get_bounding_box(outputs["boxes"][index][0])
...
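With this pipeline change, the user-supplied `threshold` is forwarded directly into `post_process_object_detection` instead of being applied afterwards. A hedged usage sketch of the zero-shot object detection pipeline (the explicit checkpoint and candidate labels are illustrative assumptions):

```python
from transformers import pipeline

detector = pipeline("zero-shot-object-detection", model="google/owlvit-base-patch32")

predictions = detector(
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    candidate_labels=["cat", "remote", "couch"],
    threshold=0.1,  # now passed through to post_process_object_detection
)

# Each prediction is a dict with "score", "label" and "box" (xmin/ymin/xmax/ymax).
for prediction in predictions:
    print(prediction["label"], round(prediction["score"], 3), prediction["box"])
```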
@@ -131,7 +131,8 @@ class ZeroShotObjectDetectionPipelineTests(unittest.TestCase, metaclass=Pipeline
        object_detector = pipeline("zero-shot-object-detection")

        outputs = object_detector(
-            "http://images.cocodataset.org/val2017/000000039769.jpg", candidate_labels=["cat", "remote", "couch"]
+            "http://images.cocodataset.org/val2017/000000039769.jpg",
+            candidate_labels=["cat", "remote", "couch"],
        )
        self.assertEqual(
            nested_simplify(outputs, decimals=4),
...