"examples/vscode:/vscode.git/clone" did not exist on "628b59e51d52d1c9885d213bc643442233451818"
Unverified Commit b92abfa6 authored by Alexander Brokking's avatar Alexander Brokking Committed by GitHub
Browse files

Add `top_k` argument to post-process of conditional/deformable-DETR (#22787)

* update min k_value of conditional detr post-processing

* feat: add top_k arg to post processing of deformable and conditional detr

* refactor: revert changes to deprecated methods

* refactor: move prob reshape to improve code clarity and reduce repetition
parent f82ee109
......@@ -1328,7 +1328,7 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
# Copied from transformers.models.deformable_detr.image_processing_deformable_detr.DeformableDetrImageProcessor.post_process_object_detection with DeformableDetr->ConditionalDetr
def post_process_object_detection(
self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None
self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None, top_k: int = 100
):
"""
Converts the raw output of [`ConditionalDetrForObjectDetection`] into final bounding boxes in (top_left_x,
......@@ -1342,6 +1342,8 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
(height, width) of each image in the batch. If left to None, predictions will not be resized.
top_k (`int`, *optional*, defaults to 100):
Keep only top k bounding boxes before filtering by thresholding.
Returns:
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
......@@ -1356,7 +1358,9 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
)
prob = out_logits.sigmoid()
topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
prob = prob.view(out_logits.shape[0], -1)
k_value = min(top_k, prob.size(1))
topk_values, topk_indexes = torch.topk(prob, k_value, dim=1)
scores = topk_values
topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
labels = topk_indexes % out_logits.shape[2]
......
......@@ -1325,7 +1325,7 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
return results
def post_process_object_detection(
self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None
self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None, top_k: int = 100
):
"""
Converts the raw output of [`DeformableDetrForObjectDetection`] into final bounding boxes in (top_left_x,
......@@ -1339,6 +1339,8 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
(height, width) of each image in the batch. If left to None, predictions will not be resized.
top_k (`int`, *optional*, defaults to 100):
Keep only top k bounding boxes before filtering by thresholding.
Returns:
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
......@@ -1353,7 +1355,9 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
)
prob = out_logits.sigmoid()
topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
prob = prob.view(out_logits.shape[0], -1)
k_value = min(top_k, prob.size(1))
topk_values, topk_indexes = torch.topk(prob, k_value, dim=1)
scores = topk_values
topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
labels = topk_indexes % out_logits.shape[2]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment