Unverified Commit 01eb34ab authored by Alara Dirik's avatar Alara Dirik Committed by GitHub

Improve DETR post-processing methods (#19205)

* Ensures consistent arguments and outputs with other post-processing methods
* Adds post_process_semantic_segmentation, post_process_instance_segmentation, post_process_panoptic_segmentation, post_process_object_detection methods to DetrFeatureExtractor
* Adds deprecation warnings to post_process, post_process_segmentation and post_process_panoptic
parent 655f72a6
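As a quick illustration of the renamed detection API, here is a hypothetical migration sketch (the checkpoint name and image path are illustrative, not part of this commit):

import torch
from PIL import Image
from transformers import DetrFeatureExtractor, DetrForObjectDetection

feature_extractor = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

image = Image.open("cats.jpg")  # illustrative local image
inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs)

target_sizes = torch.tensor([image.size[::-1]])  # (height, width) per image
# Before this PR: feature_extractor.post_process(outputs, target_sizes)
results = feature_extractor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)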
......@@ -16,6 +16,7 @@
import io
import pathlib
import warnings
from collections import defaultdict
from typing import Dict, List, Optional, Union
......@@ -555,7 +556,7 @@ class ConditionalDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtrac
if annotations is not None:
annotations = [annotations]
# Create deep copies to avoid editing inputs in place
# Create a copy of the list to avoid editing it in place
images = [image for image in images]
if annotations is not None:
......@@ -753,6 +754,11 @@ class ConditionalDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtrac
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an image
in the batch as predicted by the model.
"""
warnings.warn(
"`post_process_segmentation` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_semantic_segmentation`.",
FutureWarning,
)
out_logits, raw_masks = outputs.logits, outputs.pred_masks
preds = []
......@@ -801,6 +807,11 @@ class ConditionalDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtrac
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an
image in the batch as predicted by the model.
"""
warnings.warn(
"`post_process_instance` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_instance_segmentation`.",
FutureWarning,
)
if len(orig_target_sizes) != len(max_target_sizes):
raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes")
......@@ -845,6 +856,11 @@ class ConditionalDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtrac
`List[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for
an image in the batch as predicted by the model.
"""
warnings.warn(
"`post_process_panoptic` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_panoptic_segmentation`.",
FutureWarning,
)
if target_sizes is None:
target_sizes = processed_sizes
if len(processed_sizes) != len(target_sizes):
......
......@@ -153,8 +153,8 @@ class ConditionalDetrObjectDetectionOutput(ModelOutput):
pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
possible padding). You can use [`~ConditionalDetrFeatureExtractor.post_process`] to retrieve the
unnormalized bounding boxes.
possible padding). You can use [`~ConditionalDetrFeatureExtractor.post_process_object_detection`] to
retrieve the unnormalized bounding boxes.
auxiliary_outputs (`list[Dict]`, *optional*):
Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
......@@ -217,13 +217,14 @@ class ConditionalDetrSegmentationOutput(ModelOutput):
pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
possible padding). You can use [`~ConditionalDetrFeatureExtractor.post_process`] to retrieve the
unnormalized bounding boxes.
possible padding). You can use [`~ConditionalDetrFeatureExtractor.post_process_object_detection`] to
retrieve the unnormalized bounding boxes.
pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`):
Segmentation masks logits for all queries. See also
[`~ConditionalDetrFeatureExtractor.post_process_segmentation`] or
[`~ConditionalDetrFeatureExtractor.post_process_panoptic`] to evaluate instance and panoptic segmentation
masks respectively.
[`~ConditionalDetrFeatureExtractor.post_process_semantic_segmentation`],
[`~ConditionalDetrFeatureExtractor.post_process_instance_segmentation`] or
[`~ConditionalDetrFeatureExtractor.post_process_panoptic_segmentation`] to evaluate semantic, instance and
panoptic segmentation masks respectively.
auxiliary_outputs (`list[Dict]`, *optional*):
Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
......
......@@ -16,6 +16,7 @@
import io
import pathlib
import warnings
from collections import defaultdict
from typing import Dict, List, Optional, Union
......@@ -555,7 +556,7 @@ class DeformableDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtract
if annotations is not None:
annotations = [annotations]
# Create deep copies to avoid editing inputs in place
# Create a copy of the list to avoid editing it in place
images = [image for image in images]
if annotations is not None:
......@@ -750,6 +751,11 @@ class DeformableDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtract
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an image
in the batch as predicted by the model.
"""
warnings.warn(
"`post_process_segmentation` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_semantic_segmentation`.",
FutureWarning,
)
out_logits, raw_masks = outputs.logits, outputs.pred_masks
preds = []
......@@ -798,6 +804,11 @@ class DeformableDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtract
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an
image in the batch as predicted by the model.
"""
warnings.warn(
"`post_process_instance` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_instance_segmentation`.",
FutureWarning,
)
if len(orig_target_sizes) != len(max_target_sizes):
raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes")
......@@ -842,6 +853,11 @@ class DeformableDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtract
`List[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for
an image in the batch as predicted by the model.
"""
warnings.warn(
"`post_process_panoptic` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_panoptic_segmentation`.",
FutureWarning,
)
if target_sizes is None:
target_sizes = processed_sizes
if len(processed_sizes) != len(target_sizes):
......
......@@ -16,8 +16,9 @@
import io
import pathlib
import warnings
from collections import defaultdict
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Set, Tuple, Union
import numpy as np
from PIL import Image
......@@ -119,6 +120,54 @@ def id_to_rgb(id_map):
return color
def binary_mask_to_rle(mask):
"""
Converts a given binary mask of shape `(height, width)` to the run-length encoding (RLE) format.
Args:
mask (`torch.Tensor` or `numpy.array`):
A binary mask tensor of shape `(height, width)` where 0 denotes background and 1 denotes the target
segment_id or class_id.
Returns:
`List`: Run-length encoded list of the binary mask. Refer to the COCO API for more information about the RLE
format.
"""
if is_torch_tensor(mask):
mask = mask.numpy()
pixels = mask.flatten()
pixels = np.concatenate([[0], pixels, [0]])
runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
runs[1::2] -= runs[::2]
return [x for x in runs]
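A worked example of the helper above (a sketch assuming `binary_mask_to_rle` and numpy are in scope):

import numpy as np

mask = np.array([[0, 1, 1], [0, 1, 0]])
# Flattened, the mask reads [0, 1, 1, 0, 1, 0]: a run of two 1s starting at
# (1-indexed) position 2, then a run of one 1 starting at position 5.
print(binary_mask_to_rle(mask))  # [2, 2, 5, 1]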
def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
"""
Filters out low-scoring and null-class queries, returning the kept entries of `masks`, `scores` and `labels`.
Args:
masks (`torch.Tensor`):
A tensor of shape `(num_queries, height, width)`.
scores (`torch.Tensor`):
A tensor of shape `(num_queries)`.
labels (`torch.Tensor`):
A tensor of shape `(num_queries)`.
object_mask_threshold (`float`):
A number between 0 and 1; queries scoring at or below this threshold are discarded.
num_labels (`int`):
The number of real labels; queries whose predicted label equals `num_labels` (the null class) are discarded.
Raises:
`ValueError`: Raised when the first dimension doesn't match across all input tensors.
Returns:
`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: The `masks`, `scores` and `labels` for the queries scoring
above `object_mask_threshold` that do not predict the null class.
"""
if not (masks.shape[0] == scores.shape[0] == labels.shape[0]):
raise ValueError("masks, scores and labels must have the same first dimension!")
to_keep = labels.ne(num_labels) & (scores > object_mask_threshold)
return masks[to_keep], scores[to_keep], labels[to_keep]
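For instance (toy tensors; assumes a model with 91 real classes, so label id 91 marks the null class):

import torch

masks = torch.rand(3, 32, 32)       # 3 queries
scores = torch.tensor([0.9, 0.2, 0.7])
labels = torch.tensor([17, 5, 91])  # the third query predicts the null class
kept_masks, kept_scores, kept_labels = remove_low_and_no_objects(
    masks, scores, labels, object_mask_threshold=0.5, num_labels=91
)
# Only the first query survives: the second scores below the threshold,
# the third predicted the null class.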
class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a DETR feature extractor.
......@@ -547,7 +596,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
if annotations is not None:
annotations = [annotations]
# Create deep copies to avoid editing inputs in place
# Create a copy of the list to avoid editing it in place
images = [image for image in images]
if annotations is not None:
......@@ -699,6 +748,12 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
in the batch as predicted by the model.
"""
warnings.warn(
"`post_process` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_object_detection`.",
FutureWarning,
)
out_logits, out_bbox = outputs.logits, outputs.pred_boxes
if len(out_logits) != len(target_sizes):
......@@ -717,7 +772,6 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
boxes = boxes * scale_fct[:, None, :]
results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
return results
def post_process_segmentation(self, outputs, target_sizes, threshold=0.9, mask_threshold=0.5):
......@@ -738,6 +792,11 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an image
in the batch as predicted by the model.
"""
warnings.warn(
"`post_process_segmentation` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_semantic_segmentation`.",
FutureWarning,
)
out_logits, raw_masks = outputs.logits, outputs.pred_masks
preds = []
......@@ -786,6 +845,11 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an
image in the batch as predicted by the model.
"""
warnings.warn(
"`post_process_instance` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_instance_segmentation`.",
FutureWarning,
)
if len(orig_target_sizes) != len(max_target_sizes):
raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes")
......@@ -829,6 +893,11 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
`List[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for
an image in the batch as predicted by the model.
"""
warnings.warn(
"`post_process_panoptic` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_panoptic_segmentation`.",
FutureWarning,
)
if target_sizes is None:
target_sizes = processed_sizes
if len(processed_sizes) != len(target_sizes):
......@@ -939,3 +1008,351 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
preds.append(predictions)
return preds
def post_process_object_detection(
self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None
):
"""
Converts the output of [`DetrForObjectDetection`] into the format expected by the COCO api. Only supports
PyTorch.
Args:
outputs ([`DetrObjectDetectionOutput`]):
Raw outputs of the model.
threshold (`float`, *optional*, defaults to 0.5):
Score threshold to keep object detection predictions.
target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*, defaults to `None`):
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
(height, width) of each image in the batch. If left to `None`, predictions will not be resized.
Returns:
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
in the batch as predicted by the model.
"""
out_logits, out_bbox = outputs.logits, outputs.pred_boxes
if target_sizes is not None:
if len(out_logits) != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
prob = nn.functional.softmax(out_logits, -1)
scores, labels = prob[..., :-1].max(-1)
# Convert to [x0, y0, x1, y1] format
boxes = center_to_corners_format(out_bbox)
# Convert from relative [0, 1] to absolute [0, height] coordinates
if target_sizes is not None:
if isinstance(target_sizes, List):
img_h = torch.Tensor([i[0] for i in target_sizes])
img_w = torch.Tensor([i[1] for i in target_sizes])
else:
img_h, img_w = target_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
boxes = boxes * scale_fct[:, None, :]
results = []
for s, l, b in zip(scores, labels, boxes):
score = s[s > threshold]
label = l[s > threshold]
box = b[s > threshold]
results.append({"scores": score, "labels": label, "boxes": box})
return results
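Note that `target_sizes` is optional here; left unset, the boxes stay in relative [0, 1] coordinates (a sketch, continuing from the detection example above):

results = feature_extractor.post_process_object_detection(outputs, threshold=0.7)
# results[0]["boxes"] holds (x0, y0, x1, y1) in [0, 1]; scale by
# (width, height, width, height) yourself to get pixel coordinates.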
def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple[int, int]] = None):
"""
Converts the output of [`DetrForSegmentation`] into semantic segmentation maps. Only supports PyTorch.
Args:
outputs ([`DetrForSegmentation`]):
Raw outputs of the model.
target_sizes (`List[Tuple[int, int]]`, *optional*, defaults to `None`):
A list of tuples (`Tuple[int, int]`) containing the target size (height, width) of each image in the
batch. If left to `None`, predictions will not be resized.
Returns:
`List[torch.Tensor]`:
A list of length `batch_size`, where each item is a semantic segmentation map of shape `(height, width)`
corresponding to the `target_sizes` entry (if `target_sizes` is specified). Each entry of each
`torch.Tensor` corresponds to a semantic class id.
"""
class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1]
masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width]
# Remove the null class `[..., :-1]`
masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
# Semantic segmentation logits of shape (batch_size, num_classes, height, width)
segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
batch_size = class_queries_logits.shape[0]
# Resize logits and compute semantic segmentation maps
if target_sizes is not None:
if batch_size != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
semantic_segmentation = []
for idx in range(batch_size):
resized_logits = torch.nn.functional.interpolate(
segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
)
semantic_map = resized_logits[0].argmax(dim=0)
semantic_segmentation.append(semantic_map)
else:
semantic_segmentation = segmentation.argmax(dim=1)
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
return semantic_segmentation
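A minimal usage sketch for the semantic variant (the checkpoint name and image path are illustrative):

from PIL import Image
from transformers import DetrFeatureExtractor, DetrForSegmentation

feature_extractor = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50-panoptic")
model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")

image = Image.open("street.jpg")  # illustrative local image
inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs)

maps = feature_extractor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])
# maps[0] is a (height, width) tensor where each entry is a semantic class id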
def post_process_instance_segmentation(
self,
outputs,
threshold: float = 0.5,
overlap_mask_area_threshold: float = 0.8,
target_sizes: List[Tuple] = None,
return_coco_annotation: Optional[bool] = False,
):
"""
Converts the output of [`DetrForSegmentation`] into instance segmentation predictions. Only supports PyTorch.
Args:
outputs ([`DetrForSegmentation`]):
Raw outputs of the model.
threshold (`float`, *optional*, defaults to 0.5):
The probability score threshold to keep predicted instance masks.
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
The overlap mask area threshold to merge or discard small disconnected parts within each binary
instance mask.
target_sizes (`List[Tuple]`, *optional*, defaults to `None`):
List of length `batch_size`, where each list item (`Tuple[int, int]`) corresponds to the requested
final size (height, width) of each prediction. If left to `None`, predictions will not be resized.
return_coco_annotation (`bool`, *optional*, defaults to `False`):
If set to `True`, segmentation maps are returned in the COCO run-length encoding (RLE) format.
Returns:
`List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
- **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id`, or a
`List[List]` run-length encoding (RLE) of the segmentation map if `return_coco_annotation` is set to `True`.
- **segment_ids** -- A list of dictionaries mapping segment ids to semantic class ids, each containing:
- **id** -- An integer representing the `segment_id`.
- **label_id** -- An integer representing the segment's label / semantic class id.
"""
class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1]
masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width]
batch_size = class_queries_logits.shape[0]
num_labels = class_queries_logits.shape[-1] - 1
mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
# Predicted label and score of each query (batch_size, num_queries)
pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
# Loop over items in the batch
results: List[Dict[str, TensorType]] = []
for i in range(batch_size):
mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
)
# Fall back to the mask resolution when no target size is provided
height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs_item.device)
segments: List[Dict] = []
object_detected = mask_probs_item.shape[0] > 0
if object_detected:
# Resize mask to corresponding target_size
if target_sizes is not None:
mask_probs_item = torch.nn.functional.interpolate(
mask_probs_item.unsqueeze(0),
size=target_sizes[i],
mode="bilinear",
align_corners=False,
)[0]
current_segment_id = 0
# Weigh each mask by its prediction score
mask_probs_item *= pred_scores_item.view(-1, 1, 1)
mask_labels_item = mask_probs_item.argmax(0) # [height, width]
# Keep track of instances of each class
stuff_memory_list: Dict[str, int] = {}
for k in range(pred_labels_item.shape[0]):
# Get the mask associated with the k class
pred_class = pred_labels_item[k].item()
mask_k = mask_labels_item == k
mask_k_area = mask_k.sum()
# Compute the total area of the binarized mask predicted by query k
original_area = (mask_probs_item[k] >= 0.5).sum()
mask_exists = mask_k_area > 0 and original_area > 0
if mask_exists:
# Eliminate segments with mask area below threshold
area_ratio = mask_k_area / original_area
if area_ratio.item() <= overlap_mask_area_threshold:
continue
# Add corresponding class id
if pred_class in stuff_memory_list:
current_segment_id = stuff_memory_list[pred_class]
else:
current_segment_id += 1
# Add current object segment to final segmentation map
segmentation[mask_k] = current_segment_id
segments.append(
{
"id": current_segment_id,
"label_id": pred_class,
}
)
else:
# No query survived the filtering above: mark every pixel as -1 ("no segment")
segmentation -= 1
# Return segmentation map in run-length encoding (RLE) format
if return_coco_annotation:
segment_ids = torch.unique(segmentation)
run_length_encodings = []
for idx in segment_ids:
mask = torch.where(segmentation == idx, 1, 0)
rle = binary_mask_to_rle(mask)
run_length_encodings.append(rle)
segmentation = run_length_encodings
results.append({"segmentation": segmentation, "segment_ids": segments})
return results
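A usage sketch for the instance variant, including the COCO RLE option (continuing from the segmentation example above; the target size is illustrative):

preds = feature_extractor.post_process_instance_segmentation(outputs, threshold=0.5, target_sizes=[(480, 640)])
segmentation = preds[0]["segmentation"]  # (480, 640) tensor of per-pixel segment ids
segments = preds[0]["segment_ids"]       # [{"id": ..., "label_id": ...}, ...]

# Alternatively, return per-segment run-length encodings instead of an id map:
preds_rle = feature_extractor.post_process_instance_segmentation(
    outputs, target_sizes=[(480, 640)], return_coco_annotation=True
)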
def post_process_panoptic_segmentation(
self,
outputs,
threshold: float = 0.5,
overlap_mask_area_threshold: float = 0.8,
label_ids_to_fuse: Optional[Set[int]] = None,
target_sizes: List[Tuple] = None,
) -> List[Dict]:
"""
Converts the output of [`DetrForSegmentation`] into image panoptic segmentation predictions. Only supports
PyTorch.
Args:
outputs ([`DetrForSegmentation`]):
The outputs from [`DetrForSegmentation`].
threshold (`float`, *optional*, defaults to 0.5):
The probability score threshold to keep predicted instance masks.
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
The overlap mask area threshold to merge or discard small disconnected parts within each binary
instance mask.
label_ids_to_fuse (`Set[int]`, *optional*, defaults to `None`):
The labels in this set will have all their instances fused together. For instance, we could say there
can only be one sky in an image, but several persons, so the label ID for sky would be in that set,
but not the one for person.
target_sizes (`List[Tuple]`, *optional*, defaults to `None`):
List of length `batch_size`, where each list item (`Tuple[int, int]`) corresponds to the requested
final size (height, width) of each prediction in the batch. If left to `None`, predictions will not be
resized.
Returns:
`List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
- **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id`. If
`target_sizes` is specified, segmentation is resized to the corresponding `target_sizes` entry.
- **segment_ids** -- A list of dictionaries mapping segment ids to semantic class ids, each containing:
- **id** -- An integer representing the `segment_id`.
- **label_id** -- An integer representing the segment's label / semantic class id.
- **was_fused** -- A boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
Multiple instances of the same class / label were fused and assigned a single `segment_id`.
"""
if label_ids_to_fuse is None:
warnings.warn("`label_ids_to_fuse` unset. No instance will be fused.")
label_ids_to_fuse = set()
class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1]
masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width]
batch_size = class_queries_logits.shape[0]
num_labels = class_queries_logits.shape[-1] - 1
mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
# Predicted label and score of each query (batch_size, num_queries)
pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
# Loop over items in the batch
results: List[Dict[str, TensorType]] = []
for i in range(batch_size):
mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
)
# Fall back to the mask resolution when no target size is provided
height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs_item.device)
segments: List[Dict] = []
object_detected = mask_probs_item.shape[0] > 0
if object_detected:
# Resize mask to corresponding target_size
if target_sizes is not None:
mask_probs_item = torch.nn.functional.interpolate(
mask_probs_item.unsqueeze(0),
size=target_sizes[i],
mode="bilinear",
align_corners=False,
)[0]
current_segment_id = 0
# Weigh each mask by its prediction score
mask_probs_item *= pred_scores_item.view(-1, 1, 1)
mask_labels_item = mask_probs_item.argmax(0) # [height, width]
# Keep track of instances of each class
stuff_memory_list: Dict[str, int] = {}
for k in range(pred_labels_item.shape[0]):
pred_class = pred_labels_item[k].item()
should_fuse = pred_class in label_ids_to_fuse
# Get the mask associated with the k class
mask_k = mask_labels_item == k
mask_k_area = mask_k.sum()
# Compute the total area of the binarized mask predicted by query k
original_area = (mask_probs_item[k] >= 0.5).sum()
mask_exists = mask_k_area > 0 and original_area > 0
if mask_exists:
# Eliminate disconnected tiny segments
area_ratio = mask_k_area / original_area
if area_ratio.item() <= overlap_mask_area_threshold:
continue
# Add corresponding class id
if pred_class in stuff_memory_list:
current_segment_id = stuff_memory_list[pred_class]
else:
current_segment_id += 1
# Add current object segment to final segmentation map
segmentation[mask_k] = current_segment_id
segments.append(
{
"id": current_segment_id,
"label_id": pred_class,
"was_fused": should_fuse,
}
)
if should_fuse:
stuff_memory_list[pred_class] = current_segment_id
else:
# No query survived the filtering above: mark every pixel as -1 ("no segment")
segmentation -= 1
results.append({"segmentation": segmentation, "segment_ids": segments})
return results
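And a sketch for the panoptic variant (the label id in `label_ids_to_fuse` is illustrative and dataset-dependent):

preds = feature_extractor.post_process_panoptic_segmentation(
    outputs,
    threshold=0.5,
    label_ids_to_fuse={193},  # e.g. a "stuff" class whose instances should be merged
    target_sizes=[(480, 640)],
)
panoptic_map = preds[0]["segmentation"]  # (480, 640) tensor of segment ids
segments = preds[0]["segment_ids"]       # entries carry "id", "label_id", "was_fused"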
......@@ -148,8 +148,8 @@ class DetrObjectDetectionOutput(ModelOutput):
pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
possible padding). You can use [`~DetrFeatureExtractor.post_process`] to retrieve the unnormalized bounding
boxes.
possible padding). You can use [`~DetrFeatureExtractor.post_process_object_detection`] to retrieve the
unnormalized bounding boxes.
auxiliary_outputs (`list[Dict]`, *optional*):
Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
......@@ -211,12 +211,14 @@ class DetrSegmentationOutput(ModelOutput):
pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
possible padding). You can use [`~DetrFeatureExtractor.post_process`] to retrieve the unnormalized bounding
boxes.
possible padding). You can use [`~DetrFeatureExtractor.post_process_object_detection`] to retrieve the
unnormalized bounding boxes.
pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height/4, width/4)`):
Segmentation masks logits for all queries. See also [`~DetrFeatureExtractor.post_process_segmentation`] or
[`~DetrFeatureExtractor.post_process_panoptic`] to evaluate instance and panoptic segmentation masks
respectively.
Segmentation masks logits for all queries. See also
[`~DetrFeatureExtractor.post_process_semantic_segmentation`],
[`~DetrFeatureExtractor.post_process_instance_segmentation`] or
[`~DetrFeatureExtractor.post_process_panoptic_segmentation`] to evaluate semantic, instance and panoptic
segmentation masks respectively.
auxiliary_outputs (`list[Dict]`, *optional*):
Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
......@@ -1424,7 +1426,7 @@ class DetrForObjectDetection(DetrPreTrainedModel):
>>> # convert outputs (bounding boxes and class logits) to COCO API
>>> target_sizes = torch.tensor([image.size[::-1]])
>>> results = feature_extractor.post_process(outputs, target_sizes=target_sizes)[0]
>>> results = feature_extractor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
>>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
... box = [round(i, 2) for i in box.tolist()]
......@@ -1601,17 +1603,14 @@ class DetrForSegmentation(DetrPreTrainedModel):
>>> # forward pass
>>> outputs = model(**inputs)
>>> # use the `post_process_panoptic` method of `DetrFeatureExtractor` to convert to COCO format
>>> processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
>>> result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]
>>> # the segmentation is stored in a special-format png
>>> panoptic_seg = Image.open(io.BytesIO(result["png_string"]))
>>> panoptic_seg = numpy.array(panoptic_seg, dtype=numpy.uint8)
>>> # retrieve the ids corresponding to each mask
>>> panoptic_seg_id = rgb_to_id(panoptic_seg)
>>> panoptic_seg_id.shape
(800, 1066)
>>> # Use the `post_process_panoptic_segmentation` method of `DetrFeatureExtractor` to retrieve post-processed panoptic segmentation maps
>>> # Segmentation results are returned as a list of dictionaries
>>> result = feature_extractor.post_process_panoptic_segmentation(outputs, target_sizes=[(image.height, image.width)])
>>> # A tensor of shape (height, width) where each value denotes a segment id
>>> panoptic_seg = result[0]["segmentation"]
>>> # Get mapping of segment ids to semantic class ids
>>> panoptic_segments_info = result[0]["segment_ids"]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
......
......@@ -16,6 +16,7 @@
import io
import pathlib
import warnings
from collections import defaultdict
from typing import Dict, List, Optional, Union
......@@ -674,6 +675,12 @@ class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin)
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
in the batch as predicted by the model.
"""
warnings.warn(
"`post_process` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_object_detection`.",
FutureWarning,
)
out_logits, out_bbox = outputs.logits, outputs.pred_boxes
if len(out_logits) != len(target_sizes):
......@@ -692,7 +699,6 @@ class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin)
boxes = boxes * scale_fct[:, None, :]
results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
return results
# Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.post_process_segmentation
......@@ -714,6 +720,11 @@ class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin)
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an image
in the batch as predicted by the model.
"""
warnings.warn(
"`post_process_segmentation` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_semantic_segmentation`.",
FutureWarning,
)
out_logits, raw_masks = outputs.logits, outputs.pred_masks
preds = []
......@@ -762,6 +773,11 @@ class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin)
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an
image in the batch as predicted by the model.
"""
warnings.warn(
"`post_process_instance` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_instance_segmentation`.",
FutureWarning,
)
if len(orig_target_sizes) != len(max_target_sizes):
raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes")
......@@ -805,6 +821,11 @@ class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin)
`List[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for
an image in the batch as predicted by the model.
"""
warnings.warn(
"`post_process_panoptic` is deprecated and will be removed in v5 of Transformers, please use"
" `post_process_panoptic_segmentation`.",
FutureWarning,
)
if target_sizes is None:
target_sizes = processed_sizes
if len(processed_sizes) != len(target_sizes):
......