"vscode:/vscode.git/clone" did not exist on "9c6aeba3535898d06ffdbd9fff7fca093ec62fc2"
Unverified Commit cd245780 authored by Alara Dirik, committed by GitHub

Improve OWL-ViT postprocessing (#20980)

* add post_process_object_detection method

* style changes
parent e901914d
@@ -80,7 +80,7 @@ This model was contributed by [adirik](https://huggingface.co/adirik). The origi
 [[autodoc]] OwlViTImageProcessor
     - preprocess
-    - post_process
+    - post_process_object_detection
     - post_process_image_guided_detection
 
 ## OwlViTFeatureExtractor
@@ -14,7 +14,8 @@
 # limitations under the License.
 """Image processor class for OwlViT"""
 
-from typing import Dict, List, Optional, Union
+import warnings
+from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -344,6 +345,12 @@ class OwlViTImageProcessor(BaseImageProcessor):
             in the batch as predicted by the model.
         """
         # TODO: (amy) add support for other frameworks
+        warnings.warn(
+            "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
+            " `post_process_object_detection`",
+            FutureWarning,
+        )
+
         logits, boxes = outputs.logits, outputs.pred_boxes
 
         if len(logits) != len(target_sizes):
@@ -367,6 +374,61 @@ class OwlViTImageProcessor(BaseImageProcessor):
 
         return results
 
+    def post_process_object_detection(
+        self, outputs, threshold: float = 0.1, target_sizes: Union[TensorType, List[Tuple]] = None
+    ):
+        """
+        Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
+        bottom_right_x, bottom_right_y) format.
+
+        Args:
+            outputs ([`OwlViTObjectDetectionOutput`]):
+                Raw outputs of the model.
+            threshold (`float`, *optional*):
+                Score threshold to keep object detection predictions.
+            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
+                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
+                `(height, width)` of each image in the batch. If unset, predictions will not be resized.
+
+        Returns:
+            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+            in the batch as predicted by the model.
+        """
+        # TODO: (amy) add support for other frameworks
+        logits, boxes = outputs.logits, outputs.pred_boxes
+
+        if target_sizes is not None:
+            if len(logits) != len(target_sizes):
+                raise ValueError(
+                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+                )
+
+        probs = torch.max(logits, dim=-1)
+        scores = torch.sigmoid(probs.values)
+        labels = probs.indices
+
+        # Convert to [x0, y0, x1, y1] format
+        boxes = center_to_corners_format(boxes)
+
+        # Convert from relative [0, 1] to absolute [0, height] coordinates
+        if target_sizes is not None:
+            if isinstance(target_sizes, List):
+                img_h = torch.Tensor([i[0] for i in target_sizes])
+                img_w = torch.Tensor([i[1] for i in target_sizes])
+            else:
+                img_h, img_w = target_sizes.unbind(1)
+
+            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
+            boxes = boxes * scale_fct[:, None, :]
+
+        results = []
+        for s, l, b in zip(scores, labels, boxes):
+            score = s[s > threshold]
+            label = l[s > threshold]
+            box = b[s > threshold]
+            results.append({"scores": score, "labels": label, "boxes": box})
+
+        return results
+
     # TODO: (Amy) Make compatible with other frameworks
     def post_process_image_guided_detection(self, outputs, threshold=0.6, nms_threshold=0.3, target_sizes=None):
         """
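The deprecation above also changes the calling convention: `post_process` returned unfiltered scores that callers thresholded themselves, while `post_process_object_detection` applies the score threshold inside the call. A minimal migration sketch (not part of this diff; the `google/owlvit-base-patch32` checkpoint name and the query text are illustrative assumptions):

```python
import requests
import torch
from PIL import Image

from transformers import OwlViTForObjectDetection, OwlViTProcessor

# Illustrative checkpoint; any OWL-ViT checkpoint should behave the same.
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=[["a photo of a cat"]], images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# target_sizes expects (height, width); PIL's image.size is (width, height).
target_sizes = torch.Tensor([image.size[::-1]])

# Old path (now emits a FutureWarning and needs manual score filtering):
#   results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
# New path: the threshold is applied inside the call.
results = processor.post_process_object_detection(
    outputs=outputs, threshold=0.1, target_sizes=target_sizes
)
print(results[0]["scores"], results[0]["labels"], results[0]["boxes"])
```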
@@ -204,8 +204,8 @@ class OwlViTObjectDetectionOutput(ModelOutput):
         pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
             Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
             values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
-            possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the unnormalized
-            bounding boxes.
+            possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to retrieve the
+            unnormalized bounding boxes.
         text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim`):
             The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`].
         image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
@@ -248,13 +248,13 @@ class OwlViTImageGuidedObjectDetectionOutput(ModelOutput):
         target_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
             Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
             values are normalized in [0, 1], relative to the size of each individual target image in the batch
-            (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the
-            unnormalized bounding boxes.
+            (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to
+            retrieve the unnormalized bounding boxes.
         query_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
             Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
             values are normalized in [0, 1], relative to the size of each individual query image in the batch
-            (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the
-            unnormalized bounding boxes.
+            (disregarding possible padding). You can use [`~OwlViTFeatureExtractor.post_process_object_detection`] to
+            retrieve the unnormalized bounding boxes.
         image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
             Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes
             image embeddings for each patch.
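Both docstrings describe the same convention: predicted boxes are (center_x, center_y, width, height) values normalized to [0, 1]. A small self-contained sketch of the conversion that `post_process_object_detection` performs via `center_to_corners_format` and the `target_sizes` rescaling (the 480×640 image size is a made-up example):

```python
import torch

# One normalized box in (center_x, center_y, width, height) format.
boxes = torch.tensor([[0.5, 0.5, 0.2, 0.4]])

# Convert center/extent to (x0, y0, x1, y1) corners, as center_to_corners_format does.
cx, cy, w, h = boxes.unbind(-1)
corners = torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1)

# Rescale from relative [0, 1] to absolute pixel coordinates for a
# hypothetical image of height 480 and width 640 (the target_sizes step).
img_h, img_w = 480, 640
scale = torch.tensor([img_w, img_h, img_w, img_h])
print(corners * scale)  # tensor([[256., 144., 384., 336.]])
```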
@@ -1644,18 +1644,18 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
         >>> target_sizes = torch.Tensor([image.size[::-1]])
-        >>> # Convert outputs (bounding boxes and class logits) to COCO API
-        >>> results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
+        >>> # Convert outputs (bounding boxes and class logits) to final bounding boxes and scores
+        >>> results = processor.post_process_object_detection(
+        ...     outputs=outputs, threshold=0.1, target_sizes=target_sizes
+        ... )
 
         >>> i = 0  # Retrieve predictions for the first image for the corresponding text queries
         >>> text = texts[i]
         >>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
 
-        >>> score_threshold = 0.1
         >>> for box, score, label in zip(boxes, scores, labels):
         ...     box = [round(i, 2) for i in box.tolist()]
-        ...     if score >= score_threshold:
-        ...         print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
+        ...     print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
         Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
         Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
         ```"""
@@ -179,6 +179,13 @@ class OwlViTProcessor(ProcessorMixin):
         """
         return self.image_processor.post_process(*args, **kwargs)
 
+    def post_process_object_detection(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to [`OwlViTImageProcessor.post_process_object_detection`]. Please refer
+        to the docstring of this method for more information.
+        """
+        return self.image_processor.post_process_object_detection(*args, **kwargs)
+
     def post_process_image_guided_detection(self, *args, **kwargs):
         """
         This method forwards all its arguments to [`OwlViTImageProcessor.post_process_one_shot_object_detection`].
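The new processor method is a pure forwarder, so the two calls below should be interchangeable; a sketch assuming `processor`, `outputs`, and `target_sizes` from the migration example earlier:

```python
# OwlViTProcessor.post_process_object_detection simply delegates to the
# underlying image processor, so both calls run the same postprocessing.
results = processor.post_process_object_detection(
    outputs=outputs, threshold=0.1, target_sizes=target_sizes
)
results = processor.image_processor.post_process_object_detection(
    outputs=outputs, threshold=0.1, target_sizes=target_sizes
)
```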
@@ -173,12 +173,11 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline):
         for model_output in model_outputs:
             label = model_output["candidate_label"]
             model_output = BaseModelOutput(model_output)
-            outputs = self.feature_extractor.post_process(
-                outputs=model_output, target_sizes=model_output["target_size"]
+            outputs = self.feature_extractor.post_process_object_detection(
+                outputs=model_output, threshold=threshold, target_sizes=model_output["target_size"]
             )[0]
 
-            keep = outputs["scores"] >= threshold
-            for index in keep.nonzero():
+            for index in outputs["scores"].nonzero():
                 score = outputs["scores"][index].item()
                 box = self._get_bounding_box(outputs["boxes"][index][0])
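With thresholding moved into `post_process_object_detection`, the pipeline no longer builds a separate `keep` mask; every surviving score index is a detection. A hedged end-to-end sketch of the resulting behavior (the pipeline resolves a default OWL-ViT checkpoint, and `threshold` is forwarded to postprocessing):

```python
from transformers import pipeline

detector = pipeline("zero-shot-object-detection")
predictions = detector(
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    candidate_labels=["cat", "remote", "couch"],
    threshold=0.1,  # applied inside post_process_object_detection
)
for pred in predictions:
    print(pred["label"], round(pred["score"], 3), pred["box"])
```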
@@ -131,7 +131,8 @@ class ZeroShotObjectDetectionPipelineTests(unittest.TestCase, metaclass=Pipeline
         object_detector = pipeline("zero-shot-object-detection")
 
         outputs = object_detector(
-            "http://images.cocodataset.org/val2017/000000039769.jpg", candidate_labels=["cat", "remote", "couch"]
+            "http://images.cocodataset.org/val2017/000000039769.jpg",
+            candidate_labels=["cat", "remote", "couch"],
         )
         self.assertEqual(
             nested_simplify(outputs, decimals=4),