"docs/source/vscode:/vscode.git/clone" did not exist on "c4cc894086ba86fefbd265f9a80fc8220d2ee182"
Unverified Commit 8608bf20 authored by Rafael Padilla's avatar Rafael Padilla Committed by GitHub
Browse files

🚨🚨🚨 changing default threshold and applying threshold before the rescale (#25608)

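Apply the score threshold to the raw query scores before they are rescaled (instead of to the rescaled alphas afterwards), and change the default threshold from 0.6 to 0.0.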
parent 2df24228
...@@ -503,7 +503,7 @@ class OwlViTImageProcessor(BaseImageProcessor): ...@@ -503,7 +503,7 @@ class OwlViTImageProcessor(BaseImageProcessor):
return results return results
# TODO: (Amy) Make compatible with other frameworks # TODO: (Amy) Make compatible with other frameworks
def post_process_image_guided_detection(self, outputs, threshold=0.6, nms_threshold=0.3, target_sizes=None): def post_process_image_guided_detection(self, outputs, threshold=0.0, nms_threshold=0.3, target_sizes=None):
""" """
Converts the output of [`OwlViTForObjectDetection.image_guided_detection`] into the format expected by the COCO Converts the output of [`OwlViTForObjectDetection.image_guided_detection`] into the format expected by the COCO
api. api.
...@@ -511,7 +511,7 @@ class OwlViTImageProcessor(BaseImageProcessor): ...@@ -511,7 +511,7 @@ class OwlViTImageProcessor(BaseImageProcessor):
Args: Args:
outputs ([`OwlViTImageGuidedObjectDetectionOutput`]): outputs ([`OwlViTImageGuidedObjectDetectionOutput`]):
Raw outputs of the model. Raw outputs of the model.
threshold (`float`, *optional*, defaults to 0.6): threshold (`float`, *optional*, defaults to 0.0):
Minimum confidence threshold to use to filter out predicted boxes. Minimum confidence threshold to use to filter out predicted boxes.
nms_threshold (`float`, *optional*, defaults to 0.3): nms_threshold (`float`, *optional*, defaults to 0.3):
IoU threshold for non-maximum suppression of overlapping boxes. IoU threshold for non-maximum suppression of overlapping boxes.
...@@ -564,11 +564,13 @@ class OwlViTImageProcessor(BaseImageProcessor): ...@@ -564,11 +564,13 @@ class OwlViTImageProcessor(BaseImageProcessor):
if not query_scores.nonzero().numel(): if not query_scores.nonzero().numel():
continue continue
# Apply threshold on scores before scaling
query_scores[query_scores < threshold] = 0.0
# Scale box alpha such that the best box for each query has alpha 1.0 and the worst box has alpha 0.1. # Scale box alpha such that the best box for each query has alpha 1.0 and the worst box has alpha 0.1.
# All other boxes will either belong to a different query, or will not be shown. # All other boxes will either belong to a different query, or will not be shown.
max_score = torch.max(query_scores) + 1e-6 max_score = torch.max(query_scores) + 1e-6
query_alphas = (query_scores - (max_score * 0.1)) / (max_score * 0.9) query_alphas = (query_scores - (max_score * 0.1)) / (max_score * 0.9)
query_alphas[query_alphas < threshold] = 0.0
query_alphas = torch.clip(query_alphas, 0.0, 1.0) query_alphas = torch.clip(query_alphas, 0.0, 1.0)
alphas[idx] = query_alphas alphas[idx] = query_alphas
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment