Unverified Commit ea9caf7a authored by Rafael Padilla's avatar Rafael Padilla Committed by GitHub
Browse files

Update warning messages reffering to post_process_object_detection (#24649)

* including the threshold alert in warning messages.

* Updating doc owlvit.md including post_process_object_detection function with threshold.

* fix
parent f3e96235
...@@ -50,17 +50,13 @@ OWL-ViT is a zero-shot text-conditioned object detection model. OWL-ViT uses [CL ...@@ -50,17 +50,13 @@ OWL-ViT is a zero-shot text-conditioned object detection model. OWL-ViT uses [CL
>>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2] >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
>>> target_sizes = torch.Tensor([image.size[::-1]]) >>> target_sizes = torch.Tensor([image.size[::-1]])
>>> # Convert outputs (bounding boxes and class logits) to COCO API >>> # Convert outputs (bounding boxes and class logits) to COCO API
>>> results = processor.post_process(outputs=outputs, target_sizes=target_sizes) >>> results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, threshold=0.1)
>>> i = 0 # Retrieve predictions for the first image for the corresponding text queries >>> i = 0 # Retrieve predictions for the first image for the corresponding text queries
>>> text = texts[i] >>> text = texts[i]
>>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"] >>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
>>> score_threshold = 0.1
>>> for box, score, label in zip(boxes, scores, labels): >>> for box, score, label in zip(boxes, scores, labels):
... box = [round(i, 2) for i in box.tolist()] ... box = [round(i, 2) for i in box.tolist()]
... if score >= score_threshold: ... print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
... print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29] Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17] Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
``` ```
......
...@@ -1250,7 +1250,7 @@ class ConditionalDetrImageProcessor(BaseImageProcessor): ...@@ -1250,7 +1250,7 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
""" """
logging.warning_once( logging.warning_once(
"`post_process` is deprecated and will be removed in v5 of Transformers, please use" "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_object_detection`", " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
) )
out_logits, out_bbox = outputs.logits, outputs.pred_boxes out_logits, out_bbox = outputs.logits, outputs.pred_boxes
......
...@@ -1248,7 +1248,7 @@ class DeformableDetrImageProcessor(BaseImageProcessor): ...@@ -1248,7 +1248,7 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
""" """
logger.warning_once( logger.warning_once(
"`post_process` is deprecated and will be removed in v5 of Transformers, please use" "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_object_detection`.", " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
) )
out_logits, out_bbox = outputs.logits, outputs.pred_boxes out_logits, out_bbox = outputs.logits, outputs.pred_boxes
......
...@@ -1219,7 +1219,7 @@ class DetrImageProcessor(BaseImageProcessor): ...@@ -1219,7 +1219,7 @@ class DetrImageProcessor(BaseImageProcessor):
""" """
logger.warning_once( logger.warning_once(
"`post_process` is deprecated and will be removed in v5 of Transformers, please use" "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_object_detection`", " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
) )
out_logits, out_bbox = outputs.logits, outputs.pred_boxes out_logits, out_bbox = outputs.logits, outputs.pred_boxes
......
...@@ -354,7 +354,7 @@ class OwlViTImageProcessor(BaseImageProcessor): ...@@ -354,7 +354,7 @@ class OwlViTImageProcessor(BaseImageProcessor):
# TODO: (amy) add support for other frameworks # TODO: (amy) add support for other frameworks
warnings.warn( warnings.warn(
"`post_process` is deprecated and will be removed in v5 of Transformers, please use" "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_object_detection`", " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
FutureWarning, FutureWarning,
) )
......
...@@ -1151,7 +1151,7 @@ class YolosImageProcessor(BaseImageProcessor): ...@@ -1151,7 +1151,7 @@ class YolosImageProcessor(BaseImageProcessor):
""" """
logger.warning_once( logger.warning_once(
"`post_process` is deprecated and will be removed in v5 of Transformers, please use" "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_object_detection`", " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
) )
out_logits, out_bbox = outputs.logits, outputs.pred_boxes out_logits, out_bbox = outputs.logits, outputs.pred_boxes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment