"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "8d2fca07e85af51f50e297d14e99318c1f665a9c"
Unverified Commit a00b7e85, authored by Alara Dirik and committed by GitHub

Adds image-guided object detection support to OWL-ViT (#20136)

Adds an image-guided object detection method to the OwlViTForObjectDetection class, as described in the original paper. One-shot (image-guided) object detection lets users provide a query image and search the input image for similar objects; a minimal usage sketch follows the commit metadata below.

Co-Authored-By: Dhruv Karan <k4r4n.dhruv@gmail.com>
parent 0d0d7769
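For context, a minimal usage sketch of the API added by this commit. The checkpoint name mirrors the integration test further below; the image paths are placeholders.

```python
import torch
from PIL import Image
from transformers import OwlViTProcessor, OwlViTForObjectDetection

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

# Placeholder paths: the target image to search and a query image of the object of interest.
image = Image.open("target.png")
query_image = Image.open("query.png")

# One query image is expected per target image.
inputs = processor(images=image, query_images=query_image, return_tensors="pt")

with torch.no_grad():
    outputs = model.image_guided_detection(**inputs)

# Rescale the normalized boxes to the original image size, filter by score and apply NMS.
target_sizes = torch.tensor([image.size[::-1]])  # (height, width)
results = processor.post_process_image_guided_detection(
    outputs=outputs, threshold=0.6, nms_threshold=0.3, target_sizes=target_sizes
)
if results:
    print(results[0]["scores"], results[0]["boxes"])
```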
...@@ -80,6 +80,8 @@ This model was contributed by [adirik](https://huggingface.co/adirik). The origi
[[autodoc]] OwlViTFeatureExtractor
- __call__
- post_process
- post_process_image_guided_detection
## OwlViTProcessor
...@@ -106,3 +108,4 @@ This model was contributed by [adirik](https://huggingface.co/adirik). The origi
[[autodoc]] OwlViTForObjectDetection
- forward
- image_guided_detection
...@@ -32,14 +32,56 @@ if is_torch_available():
logger = logging.get_logger(__name__)
# Copied from transformers.models.detr.feature_extraction_detr.center_to_corners_format
def center_to_corners_format(x):
    """
    Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
    (x_0, y_0, x_1, y_1).
    """
    center_x, center_y, width, height = x.unbind(-1)
    b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
    return torch.stack(b, dim=-1)
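As a quick illustration of the conversion above (input values chosen arbitrarily):

```python
import torch

# A box centered at (0.5, 0.5) with width 0.2 and height 0.4, in (cx, cy, w, h) format.
boxes = torch.tensor([[0.5, 0.5, 0.2, 0.4]])
print(center_to_corners_format(boxes))
# tensor([[0.4000, 0.3000, 0.6000, 0.7000]])  -> (x_0, y_0, x_1, y_1)
```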
# Copied from transformers.models.detr.modeling_detr._upcast
def _upcast(t):
    # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
    if t.is_floating_point():
        return t if t.dtype in (torch.float32, torch.float64) else t.float()
    else:
        return t if t.dtype in (torch.int32, torch.int64) else t.int()
def box_area(boxes):
    """
    Computes the area of a set of bounding boxes, which are specified by their (x1, y1, x2, y2) coordinates.

    Args:
        boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
            Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
            < x2` and `0 <= y1 < y2`.

    Returns:
        `torch.FloatTensor`: a tensor containing the area for each box.
    """
    boxes = _upcast(boxes)
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
def box_iou(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    width_height = (right_bottom - left_top).clamp(min=0)  # [N,M,2]
    inter = width_height[:, :, 0] * width_height[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union
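A small worked example of the pairwise IoU helper above, with arbitrarily chosen boxes in (x0, y0, x1, y1) format:

```python
import torch

boxes1 = torch.tensor([[0.0, 0.0, 2.0, 2.0]])  # area 4
boxes2 = torch.tensor([[1.0, 1.0, 3.0, 3.0]])  # area 4, overlaps boxes1 by a 1x1 square
iou, union = box_iou(boxes1, boxes2)
print(iou)    # tensor([[0.1429]])  -> intersection 1 / union 7
print(union)  # tensor([[7.]])
```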
class OwlViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
...@@ -56,10 +98,11 @@ class OwlViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin
            The size to use for resizing the image. Only has an effect if `do_resize` is set to `True`. If `size` is a
            sequence like (h, w), output size will be matched to this. If `size` is an int, then image will be resized
            to (size, size).
        resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
            to `True`.
        do_center_crop (`bool`, *optional*, defaults to `False`):
            Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
            image is padded with 0's and then center cropped.
...@@ -111,10 +154,11 @@ class OwlViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin
        Args:
            outputs ([`OwlViTObjectDetectionOutput`]):
                Raw outputs of the model.
            target_sizes (`torch.Tensor`, *optional*):
                Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in
                the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left to
                None, predictions will not be unnormalized.

        Returns:
            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
            in the batch as predicted by the model.
...@@ -142,6 +186,82 @@ class OwlViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin
        return results
    def post_process_image_guided_detection(self, outputs, threshold=0.6, nms_threshold=0.3, target_sizes=None):
        """
        Converts the output of [`OwlViTForObjectDetection.image_guided_detection`] into the format expected by the
        COCO api.

        Args:
            outputs ([`OwlViTImageGuidedObjectDetectionOutput`]):
                Raw outputs of the model.
            threshold (`float`, *optional*, defaults to 0.6):
                Minimum confidence threshold to use to filter out predicted boxes.
            nms_threshold (`float`, *optional*, defaults to 0.3):
                IoU threshold for non-maximum suppression of overlapping boxes.
            target_sizes (`torch.Tensor`, *optional*):
                Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in
                the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left to
                None, predictions will not be unnormalized.

        Returns:
            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
            in the batch as predicted by the model. All labels are set to None as
            `OwlViTForObjectDetection.image_guided_detection` performs one-shot object detection.
        """
        logits, target_boxes = outputs.logits, outputs.target_pred_boxes

        if len(logits) != len(target_sizes):
            raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
        if target_sizes.shape[1] != 2:
            raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")

        probs = torch.max(logits, dim=-1)
        scores = torch.sigmoid(probs.values)

        # Convert to [x0, y0, x1, y1] format
        target_boxes = center_to_corners_format(target_boxes)

        # Apply non-maximum suppression (NMS)
        if nms_threshold < 1.0:
            for idx in range(target_boxes.shape[0]):
                for i in torch.argsort(-scores[idx]):
                    if not scores[idx][i]:
                        continue

                    ious = box_iou(target_boxes[idx][i, :].unsqueeze(0), target_boxes[idx])[0][0]
                    ious[i] = -1.0  # Mask self-IoU.
                    scores[idx][ious > nms_threshold] = 0.0

        # Convert from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        target_boxes = target_boxes * scale_fct[:, None, :]

        # Compute box display alphas based on prediction scores
        results = []
        alphas = torch.zeros_like(scores)

        for idx in range(target_boxes.shape[0]):
            # Select scores for boxes matching the current query:
            query_scores = scores[idx]
            if not query_scores.nonzero().numel():
                continue

            # Scale box alpha such that the best box for each query has alpha 1.0 and the worst box has alpha 0.1.
            # All other boxes will either belong to a different query, or will not be shown.
            max_score = torch.max(query_scores) + 1e-6
            query_alphas = (query_scores - (max_score * 0.1)) / (max_score * 0.9)
            query_alphas[query_alphas < threshold] = 0.0
            query_alphas = torch.clip(query_alphas, 0.0, 1.0)
            alphas[idx] = query_alphas

            mask = alphas[idx] > 0
            box_scores = alphas[idx][mask]
            boxes = target_boxes[idx][mask]
            results.append({"scores": box_scores, "labels": None, "boxes": boxes})

        return results
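A sketch of how this post-processing step might be consumed. It assumes `outputs` comes from `OwlViTForObjectDetection.image_guided_detection` and `image` is the original PIL target image; everything else is spelled out:

```python
import torch
from transformers import OwlViTFeatureExtractor

feature_extractor = OwlViTFeatureExtractor.from_pretrained("google/owlvit-base-patch32")

# `outputs` is assumed to come from OwlViTForObjectDetection.image_guided_detection,
# and `image` is the original PIL target image that was searched.
target_sizes = torch.tensor([image.size[::-1]])  # original (height, width) per image

results = feature_extractor.post_process_image_guided_detection(
    outputs=outputs, threshold=0.6, nms_threshold=0.3, target_sizes=target_sizes
)

# One dict per image; "labels" is always None since this is one-shot detection.
for result in results:
    for score, box in zip(result["scores"], result["boxes"]):
        x0, y0, x1, y1 = box.tolist()
        print(f"score={score.item():.2f}  box=({x0:.1f}, {y0:.1f}, {x1:.1f}, {y1:.1f})")
```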
    def __call__(
        self,
        images: Union[
...@@ -168,7 +288,6 @@ class OwlViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin
            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
...
...@@ -43,7 +43,7 @@ class OwlViTProcessor(ProcessorMixin):
    def __init__(self, feature_extractor, tokenizer):
        super().__init__(feature_extractor, tokenizer)

    def __call__(self, text=None, images=None, query_images=None, padding="max_length", return_tensors="np", **kwargs):
        """
        Main method to prepare for the model one or several text(s) and image(s). This method forwards the `text` and
        `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode:
...@@ -61,6 +61,10 @@ class OwlViTProcessor(ProcessorMixin):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                number of channels, H and W are image height and width.
            query_images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The query image(s) to be prepared; one query image is expected per target image to be queried. Each
                image can be a PIL image, NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each
                image should be of shape (C, H, W), where C is a number of channels, H and W are image height and width.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
...@@ -76,8 +80,10 @@ class OwlViTProcessor(ProcessorMixin):
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        """

        if text is None and query_images is None and images is None:
            raise ValueError(
                "You have to specify at least one text or query image or image. All three cannot be none."
            )

        if text is not None:
            if isinstance(text, str) or (isinstance(text, List) and not isinstance(text[0], List)):
...@@ -128,13 +134,23 @@ class OwlViTProcessor(ProcessorMixin):
            encoding["input_ids"] = input_ids
            encoding["attention_mask"] = attention_mask

        if query_images is not None:
            encoding = BatchEncoding()
            query_pixel_values = self.feature_extractor(
                query_images, return_tensors=return_tensors, **kwargs
            ).pixel_values
            encoding["query_pixel_values"] = query_pixel_values

        if images is not None:
            image_features = self.feature_extractor(images, return_tensors=return_tensors, **kwargs)

        if text is not None and images is not None:
            encoding["pixel_values"] = image_features.pixel_values
            return encoding
        elif query_images is not None and images is not None:
            encoding["pixel_values"] = image_features.pixel_values
            return encoding
        elif text is not None or query_images is not None:
            return encoding
        else:
            return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
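To make the branching above concrete, a short sketch of which keys the processor returns for the new image-guided path versus the existing text path (the expected key lists follow the processor test further below; the image paths are placeholders):

```python
from PIL import Image
from transformers import OwlViTProcessor

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
image = Image.open("target.png")  # placeholder path
query = Image.open("query.png")   # placeholder path

# Image-guided detection: one query image per target image.
inputs = processor(images=image, query_images=query, return_tensors="pt")
print(list(inputs.keys()))  # ['query_pixel_values', 'pixel_values']

# Text-guided detection is unchanged.
inputs = processor(text=["a photo of a cat"], images=image, return_tensors="pt")
print(list(inputs.keys()))  # ['input_ids', 'attention_mask', 'pixel_values']
```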
...@@ -146,6 +162,13 @@ class OwlViTProcessor(ProcessorMixin):
        """
        return self.feature_extractor.post_process(*args, **kwargs)
    def post_process_image_guided_detection(self, *args, **kwargs):
        """
        This method forwards all its arguments to [`OwlViTFeatureExtractor.post_process_image_guided_detection`].
        Please refer to the docstring of this method for more information.
        """
        return self.feature_extractor.post_process_image_guided_detection(*args, **kwargs)
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
...@@ -159,9 +182,3 @@ class OwlViTProcessor(ProcessorMixin):
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)
    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        feature_extractor_input_names = self.feature_extractor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names))
...@@ -2,6 +2,8 @@ import numpy as np
import torch
from torch.utils.data import Dataset, IterableDataset
from transformers.utils.generic import ModelOutput
class PipelineDataset(Dataset):
    def __init__(self, dataset, process, params):
...@@ -76,6 +78,14 @@ class PipelineIterator(IterableDataset):
            # Batch data is assumed to be BaseModelOutput (or dict)
            loader_batched = {}
            for k, element in self._loader_batch_data.items():
                if isinstance(element, ModelOutput):
                    # Convert ModelOutput to tuple first
                    element = element.to_tuple()
                    if isinstance(element[0], torch.Tensor):
                        loader_batched[k] = tuple(el[self._loader_batch_index].unsqueeze(0) for el in element)
                    elif isinstance(element[0], np.ndarray):
                        loader_batched[k] = tuple(np.expand_dims(el[self._loader_batch_index], 0) for el in element)
                    continue
if k in {"hidden_states", "past_key_values", "attentions"} and isinstance(element, tuple): if k in {"hidden_states", "past_key_values", "attentions"} and isinstance(element, tuple):
# Those are stored as lists of tensors so need specific unbatching. # Those are stored as lists of tensors so need specific unbatching.
if isinstance(element[0], torch.Tensor): if isinstance(element[0], torch.Tensor):
......
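A minimal sketch of what the `ModelOutput` branch above does during unbatching; the `ToyOutput` dataclass is hypothetical and stands in for a real model output:

```python
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.utils.generic import ModelOutput


@dataclass
class ToyOutput(ModelOutput):
    logits: Optional[torch.Tensor] = None
    pred_boxes: Optional[torch.Tensor] = None


# A batched output (batch size 4); the pipeline slices out one item and keeps a batch dim of 1.
element = ToyOutput(logits=torch.randn(4, 10), pred_boxes=torch.randn(4, 100, 4))
loader_batch_index = 2

fields = element.to_tuple()
unbatched = tuple(el[loader_batch_index].unsqueeze(0) for el in fields)
print([t.shape for t in unbatched])  # [torch.Size([1, 10]), torch.Size([1, 100, 4])]
```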
...@@ -19,7 +19,6 @@ import inspect
import os
import tempfile
import unittest
from typing import Dict, List, Tuple
import numpy as np
...@@ -677,52 +676,6 @@ class OwlViTForObjectDetectionTest(ModelTesterMixin, unittest.TestCase):
        self.assertTrue(models_equal)
    def test_model_outputs_equivalence(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def set_nan_tensor_to_zero(t):
            t[t != t] = 0
            return t

        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
            with torch.no_grad():
                tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
                dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()

                def recursive_check(tuple_object, dict_object):
                    if isinstance(tuple_object, (List, Tuple)):
                        for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                            recursive_check(tuple_iterable_value, dict_iterable_value)
                    elif isinstance(tuple_object, Dict):
                        for tuple_iterable_value, dict_iterable_value in zip(
                            tuple_object.values(), dict_object.values()
                        ):
                            recursive_check(tuple_iterable_value, dict_iterable_value)
                    elif tuple_object is None:
                        return
                    else:
                        self.assertTrue(
                            torch.allclose(
                                set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
                            ),
                            msg=(
                                "Tuple and dict output are not equal. Difference:"
                                f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                                f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
                                f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
                            ),
                        )

                recursive_check(tuple_output, dict_output)

        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)
            model.eval()

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs)
    @slow
    def test_model_from_pretrained(self):
        for model_name in OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
...@@ -797,3 +750,31 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
            [[0.0691, 0.0445, 0.1373], [0.1592, 0.0456, 0.3192], [0.1632, 0.0423, 0.2478]]
        ).to(torch_device)
        self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
    @slow
    def test_inference_one_shot_object_detection(self):
        model_name = "google/owlvit-base-patch32"
        model = OwlViTForObjectDetection.from_pretrained(model_name).to(torch_device)

        processor = OwlViTProcessor.from_pretrained(model_name)

        image = prepare_img()
        query_image = prepare_img()
        inputs = processor(
            images=image,
            query_images=query_image,
            max_length=16,
            padding="max_length",
            return_tensors="pt",
        ).to(torch_device)

        with torch.no_grad():
            outputs = model.image_guided_detection(**inputs)

        num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
        self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4)))

        expected_slice_boxes = torch.tensor(
            [[0.0691, 0.0445, 0.1373], [0.1592, 0.0456, 0.3192], [0.1632, 0.0423, 0.2478]]
        ).to(torch_device)
        self.assertTrue(torch.allclose(outputs.target_pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
...@@ -227,28 +227,32 @@ class OwlViTProcessorTest(unittest.TestCase):
        self.assertListEqual(list(input_ids[0]), predicted_ids[0])
        self.assertListEqual(list(input_ids[1]), predicted_ids[1])
    def test_processor_case2(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = OwlViTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        image_input = self.prepare_image_inputs()
        query_input = self.prepare_image_inputs()

        inputs = processor(images=image_input, query_images=query_input)

        self.assertListEqual(list(inputs.keys()), ["query_pixel_values", "pixel_values"])

        # test if it raises when no input is passed
        with pytest.raises(ValueError):
            processor()

    def test_tokenizer_decode(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = OwlViTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]

        decoded_processor = processor.batch_decode(predicted_ids)
        decoded_tok = tokenizer.batch_decode(predicted_ids)

        self.assertListEqual(decoded_tok, decoded_processor)

    def test_model_input_names(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = OwlViTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        input_str = "lower newer"
        image_input = self.prepare_image_inputs()

        inputs = processor(text=input_str, images=image_input)

        self.assertListEqual(list(inputs.keys()), processor.model_input_names)