Unverified Commit 002915aa authored by Alara Dirik, committed by GitHub

Owlvit docs test (#18257)

* fix docs and add owlvit docs test

* fix minor bug in post_process, add to processor

* improve owlvit code examples

* fix hardcoded image size
parent d32558cc
@@ -39,19 +39,26 @@ OWL-ViT is a zero-shot text-conditioned object detection model. OWL-ViT uses [CL
 >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 >>> image = Image.open(requests.get(url, stream=True).raw)
->>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
+>>> texts = [["a photo of a cat", "a photo of a dog"]]
+>>> inputs = processor(text=texts, images=image, return_tensors="pt")
 >>> outputs = model(**inputs)
->>> logits = outputs["logits"]  # Prediction logits of shape [batch_size, num_patches, num_max_text_queries]
->>> boxes = outputs["pred_boxes"]  # Object box boundaries of shape [batch_size, num_patches, 4]
 
->>> batch_size = boxes.shape[0]
->>> for i in range(batch_size):  # Loop over sets of images and text queries
-...     boxes = outputs["pred_boxes"][i]
-...     logits = torch.max(outputs["logits"][i], dim=-1)
-...     scores = torch.sigmoid(logits.values)
-...     labels = logits.indices
+>>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
+>>> target_sizes = torch.Tensor([image.size[::-1]])
+>>> # Convert outputs (bounding boxes and class logits) to COCO API
+>>> results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
+
+>>> i = 0  # Retrieve predictions for the first image for the corresponding text queries
+>>> text = texts[i]
+>>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
+
+>>> score_threshold = 0.1
+>>> for box, score, label in zip(boxes, scores, labels):
+...     box = [round(i, 2) for i in box.tolist()]
+...     if score >= score_threshold:
+...         print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
+Detected a photo of a cat with confidence 0.243 at location [1.42, 50.69, 308.58, 370.48]
+Detected a photo of a cat with confidence 0.298 at location [348.06, 20.56, 642.33, 372.61]
 ```
 This model was contributed by [adirik](https://huggingface.co/adirik). The original code can be found [here](https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit).
...
@@ -26,7 +26,6 @@ from ...utils import TensorType, is_torch_available, logging
 if is_torch_available():
     import torch
-    from torch import nn
 
 logger = logging.get_logger(__name__)
@@ -109,18 +108,19 @@ class OwlViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin
             `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
             in the batch as predicted by the model.
         """
-        out_logits, out_bbox = outputs.logits, outputs.pred_boxes
+        logits, boxes = outputs.logits, outputs.pred_boxes
 
-        if len(out_logits) != len(target_sizes):
+        if len(logits) != len(target_sizes):
             raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
         if target_sizes.shape[1] != 2:
             raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
 
-        prob = nn.functional.softmax(out_logits, -1)
-        scores, labels = prob[..., :-1].max(-1)
+        probs = torch.max(logits, dim=-1)
+        scores = torch.sigmoid(probs.values)
+        labels = probs.indices
 
         # Convert to [x0, y0, x1, y1] format
-        boxes = center_to_corners_format(out_bbox)
+        boxes = center_to_corners_format(boxes)
 
         # Convert from relative [0, 1] to absolute [0, height] coordinates
         img_h, img_w = target_sizes.unbind(1)
...
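The hunk above replaces the softmax-with-background scoring by a per-box sigmoid over the best-matching text query. A minimal sketch of the new scoring path on dummy logits (the `[1, 576, 2]` shape is an illustrative assumption, not taken from the diff):

```python
import torch

# Dummy prediction logits of shape [batch_size, num_patches, num_queries]
logits = torch.randn(1, 576, 2)

probs = torch.max(logits, dim=-1)      # best text query per predicted box
scores = torch.sigmoid(probs.values)   # independent confidence in [0, 1] per box
labels = probs.indices                 # index of the matching text query

print(scores.shape, labels.shape)      # torch.Size([1, 576]) torch.Size([1, 576])
```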
@@ -1300,23 +1300,31 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         >>> import torch
         >>> from transformers import OwlViTProcessor, OwlViTForObjectDetection
 
-        >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
         >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
+        >>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
 
         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
         >>> image = Image.open(requests.get(url, stream=True).raw)
-        >>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
+        >>> texts = [["a photo of a cat", "a photo of a dog"]]
+        >>> inputs = processor(text=texts, images=image, return_tensors="pt")
         >>> outputs = model(**inputs)
-        >>> logits = outputs["logits"]  # Prediction logits of shape [batch_size, num_patches, num_max_text_queries]
-        >>> boxes = outputs["pred_boxes"]  # Object box boundaries of shape [batch_size, num_patches, 4]
 
-        >>> batch_size = boxes.shape[0]
-        >>> for i in range(batch_size):  # Loop over sets of images and text queries
-        ...     boxes = outputs["pred_boxes"][i]
-        ...     logits = torch.max(outputs["logits"][i], dim=-1)
-        ...     scores = torch.sigmoid(logits.values)
-        ...     labels = logits.indices
+        >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
+        >>> target_sizes = torch.Tensor([image.size[::-1]])
+        >>> # Convert outputs (bounding boxes and class logits) to COCO API
+        >>> results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
+
+        >>> i = 0  # Retrieve predictions for the first image for the corresponding text queries
+        >>> text = texts[i]
+        >>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
+
+        >>> score_threshold = 0.1
+        >>> for box, score, label in zip(boxes, scores, labels):
+        ...     box = [round(i, 2) for i in box.tolist()]
+        ...     if score >= score_threshold:
+        ...         print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
+        Detected a photo of a cat with confidence 0.243 at location [1.42, 50.69, 308.58, 370.48]
+        Detected a photo of a cat with confidence 0.298 at location [348.06, 20.56, 642.33, 372.61]
         ```"""
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
...
@@ -139,6 +139,13 @@ class OwlViTProcessor(ProcessorMixin):
         else:
             return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
 
+    def post_process(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to [`OwlViTFeatureExtractor.post_process`]. Please refer to the
+        docstring of this method for more information.
+        """
+        return self.feature_extractor.post_process(*args, **kwargs)
+
     def batch_decode(self, *args, **kwargs):
         """
         This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
...
@@ -48,6 +48,7 @@ src/transformers/models/mobilevit/modeling_mobilevit.py
 src/transformers/models/opt/modeling_opt.py
 src/transformers/models/opt/modeling_tf_opt.py
 src/transformers/models/opt/modeling_flax_opt.py
+src/transformers/models/owlvit/modeling_owlvit.py
 src/transformers/models/pegasus/modeling_pegasus.py
 src/transformers/models/plbart/modeling_plbart.py
 src/transformers/models/poolformer/modeling_poolformer.py
...
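The last hunk registers `modeling_owlvit.py` with the documentation tests. As a rough local check (a sketch only; CI uses its own pytest-based doctest runner, and the examples download the pretrained checkpoint, so network access is needed), the standard library's doctest runner can be pointed at the same module:

```python
# Sketch: run the doc examples from the registered module with stdlib doctest.
# Assumes transformers is installed from source with PyTorch available.
import doctest

from transformers.models.owlvit import modeling_owlvit

results = doctest.testmod(modeling_owlvit, verbose=False)
print(results)  # e.g. TestResults(failed=0, attempted=...)
```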