"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "474bf508dfe0d46fc38585a1bb793e5ba74fddfd"
Unverified commit d114a6b7, authored by Roland Szabo, committed by GitHub

Add timeout parameter to load_image function (#25184)



* Add timeout parameter to load_image function.

* Remove line.

* Reformat code
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Add parameter to docs.

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 6d3f9c1e
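
For context, a minimal usage sketch of the new argument once this change is in place: load_image accepts it directly, and the vision pipelines touched below forward it from __call__ to their preprocess step. The URL here is only a placeholder.

    from transformers import pipeline
    from transformers.image_utils import load_image

    # Give up after 5 seconds instead of potentially blocking forever on a slow host.
    image = load_image("https://example.com/sample.png", timeout=5.0)

    # The pipelines changed below accept the same keyword at call time.
    depth_estimator = pipeline("depth-estimation")
    prediction = depth_estimator("https://example.com/sample.png", timeout=5.0)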
@@ -14,7 +14,7 @@
 # limitations under the License.
 import os
-from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
 import requests
@@ -253,13 +253,15 @@ def valid_coco_panoptic_annotations(annotations: Iterable[Dict[str, Union[List,
     return all(is_valid_annotation_coco_panoptic(ann) for ann in annotations)
 
 
-def load_image(image: Union[str, "PIL.Image.Image"]) -> "PIL.Image.Image":
+def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = None) -> "PIL.Image.Image":
     """
     Loads `image` to a PIL Image.
 
     Args:
         image (`str` or `PIL.Image.Image`):
             The image to convert to the PIL Image format.
+        timeout (`float`, *optional*):
+            The timeout value in seconds for the URL request.
 
     Returns:
         `PIL.Image.Image`: A PIL Image.
@@ -269,7 +271,7 @@ def load_image(image: Union[str, "PIL.Image.Image"]) -> "PIL.Image.Image":
         if image.startswith("http://") or image.startswith("https://"):
             # We need to actually check for a real protocol, otherwise it's impossible to use a local file
             # like http_huggingface_co.png
-            image = PIL.Image.open(requests.get(image, stream=True).raw)
+            image = PIL.Image.open(requests.get(image, stream=True, timeout=timeout).raw)
         elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
......
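
The timeout is handed straight to requests.get, so requests' standard behaviour applies: when the deadline is exceeded, a requests.exceptions.ConnectTimeout or ReadTimeout is raised instead of the call hanging. A hedged sketch of handling that at the call site (placeholder URL; the local fallback is the COCO fixture used in the tests below):

    import requests

    from transformers.image_utils import load_image

    try:
        image = load_image("https://example.com/sample.png", timeout=2.0)
    except (requests.exceptions.ConnectTimeout, requests.exceptions.ReadTimeout):
        # Fall back to a bundled local file when the remote host is too slow.
        image = load_image("./tests/fixtures/tests_samples/COCO/000000039769.png")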
@@ -68,6 +68,9 @@ class DepthEstimationPipeline(Pipeline):
             top_k (`int`, *optional*, defaults to 5):
                 The number of top labels that will be returned by the pipeline. If the provided number is higher than
                 the number of labels available in the model configuration, it will default to the number of labels.
+            timeout (`float`, *optional*, defaults to None):
+                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
+                the call may block forever.
 
         Return:
             A dictionary or a list of dictionaries containing result. If the input is a single image, will return a
@@ -81,11 +84,14 @@ class DepthEstimationPipeline(Pipeline):
         """
         return super().__call__(images, **kwargs)
 
-    def _sanitize_parameters(self, **kwargs):
-        return {}, {}, {}
+    def _sanitize_parameters(self, timeout=None, **kwargs):
+        preprocess_params = {}
+        if timeout is not None:
+            preprocess_params["timeout"] = timeout
+        return preprocess_params, {}, {}
 
-    def preprocess(self, image):
-        image = load_image(image)
+    def preprocess(self, image, timeout=None):
+        image = load_image(image, timeout)
         self.image_size = image.size
         model_inputs = self.image_processor(images=image, return_tensors=self.framework)
         return model_inputs
......
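
Each pipeline below repeats the same few lines because of how the base Pipeline routes keyword arguments: _sanitize_parameters returns three dicts (preprocess, forward, postprocess kwargs), and whatever lands in the first one is passed to preprocess, which is where load_image runs. An illustrative toy sketch of that contract, not actual transformers code:

    class ToyPipeline:
        def _sanitize_parameters(self, timeout=None, **kwargs):
            preprocess_params = {}
            if timeout is not None:
                preprocess_params["timeout"] = timeout
            return preprocess_params, {}, {}

        def preprocess(self, image, timeout=None):
            # In the real pipelines this is where load_image(image, timeout=timeout) happens.
            return {"image": image, "timeout": timeout}

        def __call__(self, image, **kwargs):
            preprocess_params, _forward, _postprocess = self._sanitize_parameters(**kwargs)
            return self.preprocess(image, **preprocess_params)

    print(ToyPipeline()("cats.png", timeout=5.0))  # {'image': 'cats.png', 'timeout': 5.0}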
@@ -159,6 +159,7 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
         max_seq_len=None,
         top_k=None,
         handle_impossible_answer=None,
+        timeout=None,
         **kwargs,
     ):
         preprocess_params, postprocess_params = {}, {}
@@ -174,6 +175,8 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
             preprocess_params["lang"] = lang
         if tesseract_config is not None:
             preprocess_params["tesseract_config"] = tesseract_config
+        if timeout is not None:
+            preprocess_params["timeout"] = timeout
 
         if top_k is not None:
             if top_k < 1:
@@ -244,6 +247,9 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
                 Language to use while running OCR. Defaults to english.
             tesseract_config (`str`, *optional*):
                 Additional flags to pass to tesseract while running OCR.
+            timeout (`float`, *optional*, defaults to None):
+                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
+                the call may block forever.
 
         Return:
             A `dict` or a list of `dict`: Each result comes as a dictionary with the following keys:
@@ -273,6 +279,7 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
         word_boxes: Tuple[str, List[float]] = None,
         lang=None,
         tesseract_config="",
+        timeout=None,
     ):
         # NOTE: This code mirrors the code in question answering and will be implemented in a follow up PR
         # to support documents with enough tokens that overflow the model's window
@@ -285,7 +292,7 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
         image = None
         image_features = {}
         if input.get("image", None) is not None:
-            image = load_image(input["image"])
+            image = load_image(input["image"], timeout=timeout)
             if self.image_processor is not None:
                 image_features.update(self.image_processor(images=image, return_tensors=self.framework))
             elif self.feature_extractor is not None:
......
@@ -62,11 +62,14 @@ class ImageClassificationPipeline(Pipeline):
             else MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
         )
 
-    def _sanitize_parameters(self, top_k=None):
+    def _sanitize_parameters(self, top_k=None, timeout=None):
+        preprocess_params = {}
+        if timeout is not None:
+            preprocess_params["timeout"] = timeout
         postprocess_params = {}
         if top_k is not None:
             postprocess_params["top_k"] = top_k
-        return {}, {}, postprocess_params
+        return preprocess_params, {}, postprocess_params
 
     def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
         """
@@ -86,6 +89,9 @@ class ImageClassificationPipeline(Pipeline):
             top_k (`int`, *optional*, defaults to 5):
                 The number of top labels that will be returned by the pipeline. If the provided number is higher than
                 the number of labels available in the model configuration, it will default to the number of labels.
+            timeout (`float`, *optional*, defaults to None):
+                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
+                the call may block forever.
 
         Return:
             A dictionary or a list of dictionaries containing result. If the input is a single image, will return a
@@ -99,8 +105,8 @@ class ImageClassificationPipeline(Pipeline):
         """
         return super().__call__(images, **kwargs)
 
-    def preprocess(self, image):
-        image = load_image(image)
+    def preprocess(self, image, timeout=None):
+        image = load_image(image, timeout=timeout)
         model_inputs = self.image_processor(images=image, return_tensors=self.framework)
         return model_inputs
......
@@ -89,6 +89,8 @@ class ImageSegmentationPipeline(Pipeline):
             postprocess_kwargs["mask_threshold"] = kwargs["mask_threshold"]
         if "overlap_mask_area_threshold" in kwargs:
             postprocess_kwargs["overlap_mask_area_threshold"] = kwargs["overlap_mask_area_threshold"]
+        if "timeout" in kwargs:
+            preprocess_kwargs["timeout"] = kwargs["timeout"]
 
         return preprocess_kwargs, {}, postprocess_kwargs
@@ -116,6 +118,9 @@ class ImageSegmentationPipeline(Pipeline):
                 Threshold to use when turning the predicted masks into binary values.
             overlap_mask_area_threshold (`float`, *optional*, defaults to 0.5):
                 Mask overlap threshold to eliminate small, disconnected segments.
+            timeout (`float`, *optional*, defaults to None):
+                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
+                the call may block forever.
 
         Return:
             A dictionary or a list of dictionaries containing the result. If the input is a single image, will return a
@@ -133,8 +138,8 @@ class ImageSegmentationPipeline(Pipeline):
         """
         return super().__call__(images, **kwargs)
 
-    def preprocess(self, image, subtask=None):
-        image = load_image(image)
+    def preprocess(self, image, subtask=None, timeout=None):
+        image = load_image(image, timeout=timeout)
         target_size = [(image.height, image.width)]
         if self.model.config.__class__.__name__ == "OneFormerConfig":
             if subtask is None:
......
@@ -58,12 +58,14 @@ class ImageToTextPipeline(Pipeline):
             TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
         )
 
-    def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt=None):
+    def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt=None, timeout=None):
         forward_kwargs = {}
         preprocess_params = {}
 
         if prompt is not None:
             preprocess_params["prompt"] = prompt
+        if timeout is not None:
+            preprocess_params["timeout"] = timeout
 
         if generate_kwargs is not None:
             forward_kwargs["generate_kwargs"] = generate_kwargs
@@ -97,6 +99,9 @@ class ImageToTextPipeline(Pipeline):
             generate_kwargs (`Dict`, *optional*):
                 Pass it to send all of these arguments directly to `generate` allowing full control of this function.
+            timeout (`float`, *optional*, defaults to None):
+                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
+                the call may block forever.
 
         Return:
             A list or a list of list of `dict`: Each result comes as a dictionary with the following key:
@@ -105,8 +110,8 @@ class ImageToTextPipeline(Pipeline):
         """
         return super().__call__(images, **kwargs)
 
-    def preprocess(self, image, prompt=None):
-        image = load_image(image)
+    def preprocess(self, image, prompt=None, timeout=None):
+        image = load_image(image, timeout=timeout)
 
         if prompt is not None:
             if not isinstance(prompt, str):
......
@@ -113,6 +113,8 @@ class MaskGenerationPipeline(ChunkPipeline):
             preprocess_kwargs["crop_overlap_ratio"] = kwargs["crop_overlap_ratio"]
         if "crop_n_points_downscale_factor" in kwargs:
             preprocess_kwargs["crop_n_points_downscale_factor"] = kwargs["crop_n_points_downscale_factor"]
+        if "timeout" in kwargs:
+            preprocess_kwargs["timeout"] = kwargs["timeout"]
         # postprocess args
         if "pred_iou_thresh" in kwargs:
             forward_params["pred_iou_thresh"] = kwargs["pred_iou_thresh"]
@@ -156,6 +158,9 @@ class MaskGenerationPipeline(ChunkPipeline):
                 the image length. Later layers with more crops scale down this overlap.
             crop_n_points_downscale_factor (`int`, *optional*, defaults to `1`):
                 The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+            timeout (`float`, *optional*, defaults to None):
+                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
+                the call may block forever.
 
         Return:
             `Dict`: A dictionary with the following keys:
@@ -175,8 +180,9 @@ class MaskGenerationPipeline(ChunkPipeline):
         crop_overlap_ratio: float = 512 / 1500,
         points_per_crop: Optional[int] = 32,
         crop_n_points_downscale_factor: Optional[int] = 1,
+        timeout: Optional[float] = None,
     ):
-        image = load_image(image)
+        image = load_image(image, timeout=timeout)
         target_size = self.image_processor.size["longest_edge"]
         crop_boxes, grid_points, cropped_images, input_labels = self.image_processor.generate_crop_boxes(
             image, target_size, crops_n_layers, crop_overlap_ratio, points_per_crop, crop_n_points_downscale_factor
......
@@ -61,10 +61,13 @@ class ObjectDetectionPipeline(Pipeline):
         self.check_model_type(mapping)
 
     def _sanitize_parameters(self, **kwargs):
+        preprocess_params = {}
+        if "timeout" in kwargs:
+            preprocess_params["timeout"] = kwargs["timeout"]
         postprocess_kwargs = {}
         if "threshold" in kwargs:
             postprocess_kwargs["threshold"] = kwargs["threshold"]
-        return {}, {}, postprocess_kwargs
+        return preprocess_params, {}, postprocess_kwargs
 
     def __call__(self, *args, **kwargs) -> Union[Predictions, List[Prediction]]:
         """
@@ -82,6 +85,9 @@ class ObjectDetectionPipeline(Pipeline):
                 same format: all as HTTP(S) links, all as local paths, or all as PIL images.
             threshold (`float`, *optional*, defaults to 0.9):
                 The probability necessary to make a prediction.
+            timeout (`float`, *optional*, defaults to None):
+                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
+                the call may block forever.
 
         Return:
             A list of dictionaries or a list of list of dictionaries containing the result. If the input is a single
@@ -97,8 +103,8 @@ class ObjectDetectionPipeline(Pipeline):
         """
         return super().__call__(*args, **kwargs)
 
-    def preprocess(self, image):
-        image = load_image(image)
+    def preprocess(self, image, timeout=None):
+        image = load_image(image, timeout=timeout)
         target_size = torch.IntTensor([[image.height, image.width]])
         inputs = self.image_processor(images=[image], return_tensors="pt")
         if self.tokenizer is not None:
......
@@ -55,12 +55,14 @@ class VisualQuestionAnsweringPipeline(Pipeline):
         super().__init__(*args, **kwargs)
         self.check_model_type(MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES)
 
-    def _sanitize_parameters(self, top_k=None, padding=None, truncation=None, **kwargs):
+    def _sanitize_parameters(self, top_k=None, padding=None, truncation=None, timeout=None, **kwargs):
         preprocess_params, postprocess_params = {}, {}
         if padding is not None:
             preprocess_params["padding"] = padding
         if truncation is not None:
             preprocess_params["truncation"] = truncation
+        if timeout is not None:
+            preprocess_params["timeout"] = timeout
         if top_k is not None:
             postprocess_params["top_k"] = top_k
         return preprocess_params, {}, postprocess_params
@@ -90,6 +92,9 @@ class VisualQuestionAnsweringPipeline(Pipeline):
             top_k (`int`, *optional*, defaults to 5):
                 The number of top labels that will be returned by the pipeline. If the provided number is higher than
                 the number of labels available in the model configuration, it will default to the number of labels.
+            timeout (`float`, *optional*, defaults to None):
+                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
+                the call may block forever.
 
         Return:
             A dictionary or a list of dictionaries containing the result. The dictionaries contain the following keys:
@@ -109,8 +114,8 @@ class VisualQuestionAnsweringPipeline(Pipeline):
         results = super().__call__(inputs, **kwargs)
         return results
 
-    def preprocess(self, inputs, padding=False, truncation=False):
-        image = load_image(inputs["image"])
+    def preprocess(self, inputs, padding=False, truncation=False, timeout=None):
+        image = load_image(inputs["image"], timeout=timeout)
         model_inputs = self.tokenizer(
             inputs["question"], return_tensors=self.framework, padding=padding, truncation=truncation
         )
......
@@ -91,6 +91,10 @@ class ZeroShotImageClassificationPipeline(Pipeline):
                 replacing the placeholder with the candidate_labels. Then likelihood is estimated by using
                 logits_per_image
 
+            timeout (`float`, *optional*, defaults to None):
+                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
+                the call may block forever.
+
         Return:
             A list of dictionaries containing result, one dictionary per proposed label. The dictionaries contain the
             following keys:
@@ -104,13 +108,15 @@ class ZeroShotImageClassificationPipeline(Pipeline):
         preprocess_params = {}
         if "candidate_labels" in kwargs:
             preprocess_params["candidate_labels"] = kwargs["candidate_labels"]
+        if "timeout" in kwargs:
+            preprocess_params["timeout"] = kwargs["timeout"]
         if "hypothesis_template" in kwargs:
             preprocess_params["hypothesis_template"] = kwargs["hypothesis_template"]
         return preprocess_params, {}, {}
 
-    def preprocess(self, image, candidate_labels=None, hypothesis_template="This is a photo of {}."):
-        image = load_image(image)
+    def preprocess(self, image, candidate_labels=None, hypothesis_template="This is a photo of {}.", timeout=None):
+        image = load_image(image, timeout=timeout)
         inputs = self.image_processor(images=[image], return_tensors=self.framework)
         inputs["candidate_labels"] = candidate_labels
         sequences = [hypothesis_template.format(x) for x in candidate_labels]
......
@@ -111,6 +111,10 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline):
                 The number of top predictions that will be returned by the pipeline. If the provided number is `None`
                 or higher than the number of predictions available, it will default to the number of predictions.
 
+            timeout (`float`, *optional*, defaults to None):
+                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
+                the call may block forever.
+
         Return:
             A list of lists containing prediction results, one list per input image. Each list contains dictionaries
@@ -132,15 +136,18 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline):
         return results
 
     def _sanitize_parameters(self, **kwargs):
+        preprocess_params = {}
+        if "timeout" in kwargs:
+            preprocess_params["timeout"] = kwargs["timeout"]
         postprocess_params = {}
         if "threshold" in kwargs:
             postprocess_params["threshold"] = kwargs["threshold"]
         if "top_k" in kwargs:
             postprocess_params["top_k"] = kwargs["top_k"]
-        return {}, {}, postprocess_params
+        return preprocess_params, {}, postprocess_params
 
-    def preprocess(self, inputs):
-        image = load_image(inputs["image"])
+    def preprocess(self, inputs, timeout=None):
+        image = load_image(inputs["image"], timeout=timeout)
         candidate_labels = inputs["candidate_labels"]
         if isinstance(candidate_labels, str):
             candidate_labels = candidate_labels.split(",")
......
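
Pipelines that collect options through **kwargs (object detection, the zero-shot pipelines, mask generation) pick the timeout up the same way, so it can simply be added to the call. A hedged example for zero-shot object detection (placeholder URL and labels):

    from transformers import pipeline

    detector = pipeline("zero-shot-object-detection")
    detector(
        "https://example.com/sample.png",
        candidate_labels=["cat", "remote control"],
        timeout=5.0,
    )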
@@ -18,7 +18,9 @@ import unittest
 import datasets
 import numpy as np
 import pytest
+from requests import ReadTimeout
 
+from tests.pipelines.test_pipelines_document_question_answering import INVOICE_URL
 from transformers import is_torch_available, is_vision_available
 from transformers.image_utils import ChannelDimension, get_channel_dimension_axis, make_list_of_images
 from transformers.testing_utils import require_torch, require_vision
@@ -478,6 +480,16 @@ class ImageFeatureExtractionTester(unittest.TestCase):
 @require_vision
 class LoadImageTester(unittest.TestCase):
+    def test_load_img_url(self):
+        img = load_image(INVOICE_URL)
+        img_arr = np.array(img)
+
+        self.assertEqual(img_arr.shape, (1061, 750, 3))
+
+    def test_load_img_url_timeout(self):
+        with self.assertRaises(ReadTimeout):
+            load_image(INVOICE_URL, timeout=0.001)
+
     def test_load_img_local(self):
         img = load_image("./tests/fixtures/tests_samples/COCO/000000039769.png")
         img_arr = np.array(img)
......