Unverified Commit 01eb34ab authored by Alara Dirik's avatar Alara Dirik Committed by GitHub
Browse files

Improve DETR post-processing methods (#19205)

* Ensures consistent arguments and outputs with other post-processing methods
* Adds post_process_semantic_segmentation, post_process_instance_segmentation, post_process_panoptic_segmentation, post_process_object_detection methods to DetrFeatureExtractor
* Adds deprecation warnings to post_process, post_process_segmentation and post_process_panoptic
parent 655f72a6
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import io import io
import pathlib import pathlib
import warnings
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
...@@ -555,7 +556,7 @@ class ConditionalDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtrac ...@@ -555,7 +556,7 @@ class ConditionalDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtrac
if annotations is not None: if annotations is not None:
annotations = [annotations] annotations = [annotations]
# Create deep copies to avoid editing inputs in place # Create a copy of the list to avoid editing it in place
images = [image for image in images] images = [image for image in images]
if annotations is not None: if annotations is not None:
...@@ -753,6 +754,11 @@ class ConditionalDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtrac ...@@ -753,6 +754,11 @@ class ConditionalDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtrac
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an image `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an image
in the batch as predicted by the model. in the batch as predicted by the model.
""" """
warnings.warn(
"`post_process_segmentation` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_semantic_segmentation`.",
FutureWarning,
)
out_logits, raw_masks = outputs.logits, outputs.pred_masks out_logits, raw_masks = outputs.logits, outputs.pred_masks
preds = [] preds = []
...@@ -801,6 +807,11 @@ class ConditionalDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtrac ...@@ -801,6 +807,11 @@ class ConditionalDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtrac
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an
image in the batch as predicted by the model. image in the batch as predicted by the model.
""" """
warnings.warn(
"`post_process_instance` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_instance_segmentation`.",
FutureWarning,
)
if len(orig_target_sizes) != len(max_target_sizes): if len(orig_target_sizes) != len(max_target_sizes):
raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes") raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes")
...@@ -845,6 +856,11 @@ class ConditionalDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtrac ...@@ -845,6 +856,11 @@ class ConditionalDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtrac
`List[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for `List[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for
an image in the batch as predicted by the model. an image in the batch as predicted by the model.
""" """
warnings.warn(
"`post_process_panoptic` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_panoptic_segmentation`.",
FutureWarning,
)
if target_sizes is None: if target_sizes is None:
target_sizes = processed_sizes target_sizes = processed_sizes
if len(processed_sizes) != len(target_sizes): if len(processed_sizes) != len(target_sizes):
......
...@@ -153,8 +153,8 @@ class ConditionalDetrObjectDetectionOutput(ModelOutput): ...@@ -153,8 +153,8 @@ class ConditionalDetrObjectDetectionOutput(ModelOutput):
pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
possible padding). You can use [`~ConditionalDetrFeatureExtractor.post_process`] to retrieve the possible padding). You can use [`~ConditionalDetrFeatureExtractor.post_process_object_detection`] to
unnormalized bounding boxes. retrieve the unnormalized bounding boxes.
auxiliary_outputs (`list[Dict]`, *optional*): auxiliary_outputs (`list[Dict]`, *optional*):
Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
...@@ -217,13 +217,14 @@ class ConditionalDetrSegmentationOutput(ModelOutput): ...@@ -217,13 +217,14 @@ class ConditionalDetrSegmentationOutput(ModelOutput):
pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
possible padding). You can use [`~ConditionalDetrFeatureExtractor.post_process`] to retrieve the possible padding). You can use [`~ConditionalDetrFeatureExtractor.post_process_object_detection`] to
unnormalized bounding boxes. retrieve the unnormalized bounding boxes.
pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`): pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`):
Segmentation masks logits for all queries. See also Segmentation masks logits for all queries. See also
[`~ConditionalDetrFeatureExtractor.post_process_segmentation`] or [`~ConditionalDetrFeatureExtractor.post_process_semantic_segmentation`] or
[`~ConditionalDetrFeatureExtractor.post_process_panoptic`] to evaluate instance and panoptic segmentation [`~ConditionalDetrFeatureExtractor.post_process_instance_segmentation`]
masks respectively. [`~ConditionalDetrFeatureExtractor.post_process_panoptic_segmentation`] to evaluate semantic, instance and
panoptic segmentation masks respectively.
auxiliary_outputs (`list[Dict]`, *optional*): auxiliary_outputs (`list[Dict]`, *optional*):
Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import io import io
import pathlib import pathlib
import warnings
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
...@@ -555,7 +556,7 @@ class DeformableDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtract ...@@ -555,7 +556,7 @@ class DeformableDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtract
if annotations is not None: if annotations is not None:
annotations = [annotations] annotations = [annotations]
# Create deep copies to avoid editing inputs in place # Create a copy of the list to avoid editing it in place
images = [image for image in images] images = [image for image in images]
if annotations is not None: if annotations is not None:
...@@ -750,6 +751,11 @@ class DeformableDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtract ...@@ -750,6 +751,11 @@ class DeformableDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtract
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an image `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an image
in the batch as predicted by the model. in the batch as predicted by the model.
""" """
warnings.warn(
"`post_process_segmentation` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_semantic_segmentation`.",
FutureWarning,
)
out_logits, raw_masks = outputs.logits, outputs.pred_masks out_logits, raw_masks = outputs.logits, outputs.pred_masks
preds = [] preds = []
...@@ -798,6 +804,11 @@ class DeformableDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtract ...@@ -798,6 +804,11 @@ class DeformableDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtract
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an
image in the batch as predicted by the model. image in the batch as predicted by the model.
""" """
warnings.warn(
"`post_process_instance` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_instance_segmentation`.",
FutureWarning,
)
if len(orig_target_sizes) != len(max_target_sizes): if len(orig_target_sizes) != len(max_target_sizes):
raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes") raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes")
...@@ -842,6 +853,11 @@ class DeformableDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtract ...@@ -842,6 +853,11 @@ class DeformableDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtract
`List[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for `List[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for
an image in the batch as predicted by the model. an image in the batch as predicted by the model.
""" """
warnings.warn(
"`post_process_panoptic` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_panoptic_segmentation`.",
FutureWarning,
)
if target_sizes is None: if target_sizes is None:
target_sizes = processed_sizes target_sizes = processed_sizes
if len(processed_sizes) != len(target_sizes): if len(processed_sizes) != len(target_sizes):
......
...@@ -148,8 +148,8 @@ class DetrObjectDetectionOutput(ModelOutput): ...@@ -148,8 +148,8 @@ class DetrObjectDetectionOutput(ModelOutput):
pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
possible padding). You can use [`~DetrFeatureExtractor.post_process`] to retrieve the unnormalized bounding possible padding). You can use [`~DetrFeatureExtractor.post_process_object_detection`] to retrieve the
boxes. unnormalized bounding boxes.
auxiliary_outputs (`list[Dict]`, *optional*): auxiliary_outputs (`list[Dict]`, *optional*):
Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
...@@ -211,12 +211,14 @@ class DetrSegmentationOutput(ModelOutput): ...@@ -211,12 +211,14 @@ class DetrSegmentationOutput(ModelOutput):
pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
possible padding). You can use [`~DetrFeatureExtractor.post_process`] to retrieve the unnormalized bounding possible padding). You can use [`~DetrFeatureExtractor.post_process_object_detection`] to retrieve the
boxes. unnormalized bounding boxes.
pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`): pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`):
Segmentation masks logits for all queries. See also [`~DetrFeatureExtractor.post_process_segmentation`] or Segmentation masks logits for all queries. See also
[`~DetrFeatureExtractor.post_process_panoptic`] to evaluate instance and panoptic segmentation masks [`~DetrFeatureExtractor.post_process_semantic_segmentation`] or
respectively. [`~DetrFeatureExtractor.post_process_instance_segmentation`]
[`~DetrFeatureExtractor.post_process_panoptic_segmentation`] to evaluate semantic, instance and panoptic
segmentation masks respectively.
auxiliary_outputs (`list[Dict]`, *optional*): auxiliary_outputs (`list[Dict]`, *optional*):
Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
...@@ -1424,7 +1426,7 @@ class DetrForObjectDetection(DetrPreTrainedModel): ...@@ -1424,7 +1426,7 @@ class DetrForObjectDetection(DetrPreTrainedModel):
>>> # convert outputs (bounding boxes and class logits) to COCO API >>> # convert outputs (bounding boxes and class logits) to COCO API
>>> target_sizes = torch.tensor([image.size[::-1]]) >>> target_sizes = torch.tensor([image.size[::-1]])
>>> results = feature_extractor.post_process(outputs, target_sizes=target_sizes)[0] >>> results = feature_extractor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
>>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
... box = [round(i, 2) for i in box.tolist()] ... box = [round(i, 2) for i in box.tolist()]
...@@ -1601,17 +1603,14 @@ class DetrForSegmentation(DetrPreTrainedModel): ...@@ -1601,17 +1603,14 @@ class DetrForSegmentation(DetrPreTrainedModel):
>>> # forward pass >>> # forward pass
>>> outputs = model(**inputs) >>> outputs = model(**inputs)
>>> # use the `post_process_panoptic` method of `DetrFeatureExtractor` to convert to COCO format >>> # Use the `post_process_panoptic_segmentation` method of `DetrFeatureExtractor` to retrieve post-processed panoptic segmentation maps
>>> processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0) >>> # Segmentation results are returned as a list of dictionaries
>>> result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0] >>> result = feature_extractor.post_process_panoptic_segmentation(outputs)
>>> # the segmentation is stored in a special-format png >>> # A tensor of shape (height, width) where each value denotes a segment id
>>> panoptic_seg = Image.open(io.BytesIO(result["png_string"])) >>> panoptic_seg = result[0]["segmentation"]
>>> panoptic_seg = numpy.array(panoptic_seg, dtype=numpy.uint8) >>> # Get mapping of segment ids to semantic class ids
>>> # retrieve the ids corresponding to each mask >>> panoptic_segments_info = result[0]["segments_info"]
>>> panoptic_seg_id = rgb_to_id(panoptic_seg)
>>> panoptic_seg_id.shape
(800, 1066)
```""" ```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import io import io
import pathlib import pathlib
import warnings
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
...@@ -674,6 +675,12 @@ class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin) ...@@ -674,6 +675,12 @@ class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin)
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
in the batch as predicted by the model. in the batch as predicted by the model.
""" """
warnings.warn(
"`post_process` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_object_detection`.",
FutureWarning,
)
out_logits, out_bbox = outputs.logits, outputs.pred_boxes out_logits, out_bbox = outputs.logits, outputs.pred_boxes
if len(out_logits) != len(target_sizes): if len(out_logits) != len(target_sizes):
...@@ -692,7 +699,6 @@ class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin) ...@@ -692,7 +699,6 @@ class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin)
boxes = boxes * scale_fct[:, None, :] boxes = boxes * scale_fct[:, None, :]
results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)] results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
return results return results
# Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.post_process_segmentation # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.post_process_segmentation
...@@ -714,6 +720,11 @@ class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin) ...@@ -714,6 +720,11 @@ class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin)
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an image `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an image
in the batch as predicted by the model. in the batch as predicted by the model.
""" """
warnings.warn(
"`post_process_segmentation` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_semantic_segmentation`.",
FutureWarning,
)
out_logits, raw_masks = outputs.logits, outputs.pred_masks out_logits, raw_masks = outputs.logits, outputs.pred_masks
preds = [] preds = []
...@@ -762,6 +773,11 @@ class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin) ...@@ -762,6 +773,11 @@ class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin)
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an
image in the batch as predicted by the model. image in the batch as predicted by the model.
""" """
warnings.warn(
"`post_process_instance` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_instance_segmentation`.",
FutureWarning,
)
if len(orig_target_sizes) != len(max_target_sizes): if len(orig_target_sizes) != len(max_target_sizes):
raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes") raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes")
...@@ -805,6 +821,11 @@ class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin) ...@@ -805,6 +821,11 @@ class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin)
`List[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for `List[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for
an image in the batch as predicted by the model. an image in the batch as predicted by the model.
""" """
warnings.warn(
"`post_process_panoptic` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_panoptic_segmentation`.",
FutureWarning,
)
if target_sizes is None: if target_sizes is None:
target_sizes = processed_sizes target_sizes = processed_sizes
if len(processed_sizes) != len(target_sizes): if len(processed_sizes) != len(target_sizes):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment