"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "39acfe84ba330fda3ae72c083284a04cac8ac9e0"
Unverified commit b92abfa6, authored by Alexander Brokking, committed by GitHub

Add `top_k` argument to post-process of conditional/deformable-DETR (#22787)

* update min k_value of conditional detr post-processing

* feat: add top_k arg to post processing of deformable and conditional detr

* refactor: revert changes to deprecated methods

* refactor: move prob reshape to improve code clarity and reduce repetition
parent f82ee109
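From the caller's side, the change means `post_process_object_detection` now accepts a `top_k` argument that caps how many candidate boxes are ranked before the score threshold is applied (previously hardcoded to 100). A minimal usage sketch, assuming a public Deformable-DETR checkpoint; the image path is a placeholder, not part of this commit:

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, DeformableDetrForObjectDetection

# Assumed checkpoint; any Deformable-DETR checkpoint works the same way.
processor = AutoImageProcessor.from_pretrained("SenseTime/deformable-detr")
model = DeformableDetrForObjectDetection.from_pretrained("SenseTime/deformable-detr")

image = Image.open("example.jpg")  # placeholder local image
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# New in this commit: `top_k` controls how many candidates are kept
# before thresholding, instead of the previously hardcoded 100.
results = processor.post_process_object_detection(
    outputs,
    threshold=0.5,
    target_sizes=torch.tensor([image.size[::-1]]),  # PIL size is (w, h) -> (h, w)
    top_k=300,
)[0]
```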
src/transformers/models/conditional_detr/image_processing_conditional_detr.py

@@ -1328,7 +1328,7 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
     # Copied from transformers.models.deformable_detr.image_processing_deformable_detr.DeformableDetrImageProcessor.post_process_object_detection with DeformableDetr->ConditionalDetr
     def post_process_object_detection(
-        self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None
+        self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None, top_k: int = 100
     ):
         """
         Converts the raw output of [`ConditionalDetrForObjectDetection`] into final bounding boxes in (top_left_x,
@@ -1342,6 +1342,8 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
             target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
                 Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
                 (height, width) of each image in the batch. If left to None, predictions will not be resized.
+            top_k (`int`, *optional*, defaults to 100):
+                Keep only top k bounding boxes before filtering by thresholding.

         Returns:
             `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
@@ -1356,7 +1358,9 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
             )

         prob = out_logits.sigmoid()
-        topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
+        prob = prob.view(out_logits.shape[0], -1)
+        k_value = min(top_k, prob.size(1))
+        topk_values, topk_indexes = torch.topk(prob, k_value, dim=1)
         scores = topk_values
         topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
         labels = topk_indexes % out_logits.shape[2]
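Both hunks rest on the same flattened top-k trick: the `(queries, classes)` probability grid is flattened into one axis so a single `torch.topk` call ranks every (box, label) pair at once, and floor-division and modulo then recover the two components. A standalone sketch; all tensor shapes are illustrative assumptions, not values from a real model:

```python
import torch

batch_size, num_queries, num_classes = 1, 4, 3  # illustrative shapes
out_logits = torch.randn(batch_size, num_queries, num_classes)

prob = out_logits.sigmoid()
# Flatten (queries, classes) into one axis so a single topk call ranks
# every (box, label) pair at once.
prob = prob.view(batch_size, -1)

top_k = 100
k_value = min(top_k, prob.size(1))  # clamp: topk fails if k > prob.size(1)
topk_values, topk_indexes = torch.topk(prob, k_value, dim=1)

# Each flat index i encodes (query, class) as i = query * num_classes + class,
# so floor-division recovers the box index and modulo recovers the label.
topk_boxes = torch.div(topk_indexes, num_classes, rounding_mode="floor")
labels = topk_indexes % num_classes

assert (topk_boxes * num_classes + labels == topk_indexes).all()
```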
src/transformers/models/deformable_detr/image_processing_deformable_detr.py

@@ -1325,7 +1325,7 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
         return results

     def post_process_object_detection(
-        self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None
+        self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None, top_k: int = 100
     ):
         """
         Converts the raw output of [`DeformableDetrForObjectDetection`] into final bounding boxes in (top_left_x,
@@ -1339,6 +1339,8 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
             target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
                 Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
                 (height, width) of each image in the batch. If left to None, predictions will not be resized.
+            top_k (`int`, *optional*, defaults to 100):
+                Keep only top k bounding boxes before filtering by thresholding.

         Returns:
             `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
@@ -1353,7 +1355,9 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
             )

         prob = out_logits.sigmoid()
-        topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
+        prob = prob.view(out_logits.shape[0], -1)
+        k_value = min(top_k, prob.size(1))
+        topk_values, topk_indexes = torch.topk(prob, k_value, dim=1)
         scores = topk_values
         topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
         labels = topk_indexes % out_logits.shape[2]
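The `min(top_k, prob.size(1))` clamp (the first commit message above) is what makes the new argument safe: `torch.topk` raises a `RuntimeError` when `k` exceeds the size of the ranked dimension, which can happen for models with few queries or classes. A small illustration, with an assumed candidate count of 12:

```python
import torch

prob = torch.rand(1, 12)  # assumed: only 12 (query, class) candidates

# Without the clamp, asking for more candidates than exist is an error:
try:
    torch.topk(prob, 100, dim=1)
except RuntimeError as err:
    print("topk failed:", err)

# With the clamp from this commit, top_k degrades gracefully:
k_value = min(100, prob.size(1))
topk_values, topk_indexes = torch.topk(prob, k_value, dim=1)
print(topk_values.shape)  # torch.Size([1, 12])
```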