Unverified Commit 36f52e95 authored by Alara Dirik, committed by GitHub

Restructure DETR post-processing, return prediction scores (#19262)

* Restructure DetrFeatureExtractor post-processing methods
* Update post_process_instance_segmentation and post_process_panoptic_segmentation methods to return prediction scores
* Update DETR models docs
parent 5cd16f01
@@ -171,9 +171,9 @@ mean Average Precision (mAP) and Panoptic Quality (PQ). The latter objects are i

 [[autodoc]] DetrFeatureExtractor
     - __call__
     - pad_and_create_pixel_mask
-    - post_process
-    - post_process_segmentation
-    - post_process_panoptic
+    - post_process_semantic_segmentation
+    - post_process_instance_segmentation
+    - post_process_panoptic_segmentation

 ## DetrModel
@@ -141,11 +141,33 @@ def binary_mask_to_rle(mask):
     return [x for x in runs]


+def convert_segmentation_to_rle(segmentation):
+    """
+    Converts a given segmentation map of shape `(height, width)` to the run-length encoding (RLE) format.
+
+    Args:
+        segmentation (`torch.Tensor` or `numpy.array`):
+            A segmentation map of shape `(height, width)` where each value denotes a segment or class id.
+    Returns:
+        `List[List]`: A list of lists, where each list is the run-length encoding of a segment / class id.
+    """
+    segment_ids = torch.unique(segmentation)
+
+    run_length_encodings = []
+    for idx in segment_ids:
+        mask = torch.where(segmentation == idx, 1, 0)
+        rle = binary_mask_to_rle(mask)
+        run_length_encodings.append(rle)
+
+    return run_length_encodings
+
+
 def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
     """
+    Binarize the given masks using `object_mask_threshold` and return the associated values of `masks`, `scores` and
+    `labels`.
+
     Args:
-    Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores`
-    and `labels`.
         masks (`torch.Tensor`):
             A tensor of shape `(num_queries, height, width)`.
         scores (`torch.Tensor`):
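The RLE helpers are easiest to understand on a toy input. Below is a minimal, self-contained sketch; since the body of `binary_mask_to_rle` is elided in this hunk, the encoding shown (uncompressed COCO-style alternating start/length runs, 1-indexed over the flattened mask) is an assumption based on how the new `convert_segmentation_to_rle` consumes it:

```python
import numpy as np
import torch


def binary_mask_to_rle_sketch(mask: torch.Tensor):
    """Hypothetical stand-in for `binary_mask_to_rle`: uncompressed COCO-style RLE,
    i.e. alternating (start, length) pairs of 1-runs over the flattened mask,
    with 1-indexed start positions."""
    pixels = mask.numpy().flatten()
    pixels = np.concatenate([[0], pixels, [0]])        # pad so runs touching the edges are closed
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1  # 1-indexed positions where the value flips
    runs[1::2] -= runs[::2]                            # turn (start, end) pairs into (start, length)
    return [int(x) for x in runs]


mask = torch.tensor([[0, 1, 1], [0, 1, 0]])
print(binary_mask_to_rle_sketch(mask))  # [2, 2, 5, 1]: two 1s starting at pixel 2, one 1 at pixel 5
```

`convert_segmentation_to_rle` simply applies this per segment id: each id becomes a binary mask via `torch.where`, and each mask becomes one RLE list.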
@@ -168,6 +190,81 @@ def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_
     return masks[to_keep], scores[to_keep], labels[to_keep]


+def check_segment_validity(mask_labels, mask_probs, k, overlap_mask_area_threshold=0.8):
+    # Get the mask associated with the k class
+    mask_k = mask_labels == k
+    mask_k_area = mask_k.sum()
+
+    # Compute the area of all the stuff in query k
+    original_area = (mask_probs[k] >= 0.5).sum()
+    mask_exists = mask_k_area > 0 and original_area > 0
+
+    # Eliminate disconnected tiny segments
+    if mask_exists:
+        area_ratio = mask_k_area / original_area
+        if not area_ratio.item() > overlap_mask_area_threshold:
+            mask_exists = False
+
+    return mask_exists, mask_k
+
+
+def compute_segments(
+    mask_probs,
+    pred_scores,
+    pred_labels,
+    overlap_mask_area_threshold: float = 0.8,
+    label_ids_to_fuse: Optional[Set[int]] = None,
+    target_size: Optional[Tuple[int, int]] = None,
+):
+    height = mask_probs.shape[1] if target_size is None else target_size[0]
+    width = mask_probs.shape[2] if target_size is None else target_size[1]
+
+    # Treat a missing fuse set as empty so the membership check below cannot fail
+    if label_ids_to_fuse is None:
+        label_ids_to_fuse = set()
+
+    segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device)
+    segments: List[Dict] = []
+
+    if target_size is not None:
+        mask_probs = nn.functional.interpolate(
+            mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False
+        )[0]
+
+    current_segment_id = 0
+
+    # Weigh each mask by its prediction score
+    mask_probs *= pred_scores.view(-1, 1, 1)
+    mask_labels = mask_probs.argmax(0)  # [height, width]
+
+    # Keep track of instances of each class
+    stuff_memory_list: Dict[str, int] = {}
+    for k in range(pred_labels.shape[0]):
+        pred_class = pred_labels[k].item()
+        should_fuse = pred_class in label_ids_to_fuse
+
+        # Check if mask exists and is large enough to be a segment
+        mask_exists, mask_k = check_segment_validity(mask_labels, mask_probs, k, overlap_mask_area_threshold)
+
+        if mask_exists:
+            if pred_class in stuff_memory_list:
+                current_segment_id = stuff_memory_list[pred_class]
+            else:
+                current_segment_id += 1
+
+            # Add current object segment to final segmentation map
+            segmentation[mask_k] = current_segment_id
+            segment_score = round(pred_scores[k].item(), 6)
+            segments.append(
+                {
+                    "id": current_segment_id,
+                    "label_id": pred_class,
+                    "was_fused": should_fuse,
+                    "score": segment_score,
+                }
+            )
+            if should_fuse:
+                stuff_memory_list[pred_class] = current_segment_id
+
+    return segmentation, segments
+
+
 class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
     r"""
     Constructs a DETR feature extractor.
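To make the new helpers concrete, here is a toy invocation of `compute_segments` (a sketch; it assumes the two functions above are in scope, e.g. imported from `transformers.models.detr.feature_extraction_detr`, and the class ids and probabilities are made up):

```python
import torch

# Three query masks over a 4x4 grid: two overlapping "class 17" queries and one
# "class 25" query. Values are per-pixel mask probabilities.
mask_probs = torch.zeros(3, 4, 4)
mask_probs[0, :2, :] = 0.9   # query 0 claims the top half
mask_probs[1, 2:, :2] = 0.8  # query 1 claims the bottom-left quadrant
mask_probs[2, 2:, 2:] = 0.7  # query 2 claims the bottom-right quadrant

pred_scores = torch.tensor([0.95, 0.85, 0.75])
pred_labels = torch.tensor([17, 17, 25])

segmentation, segments = compute_segments(
    mask_probs,
    pred_scores,
    pred_labels,
    overlap_mask_area_threshold=0.8,
    label_ids_to_fuse={17},  # fuse all class-17 instances into a single segment id
    target_size=(8, 8),      # upsample the 4x4 masks to the "original" image size
)

print(segmentation.shape)  # torch.Size([8, 8]); each pixel holds a segment id
for segment in segments:
    print(segment)  # {'id': ..., 'label_id': ..., 'was_fused': ..., 'score': ...}
```

Because both class-17 queries share one entry in `stuff_memory_list`, they receive the same segment id; the unfused class-25 query gets its own.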
@@ -1098,7 +1195,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):

         semantic_segmentation = []
         for idx in range(batch_size):
-            resized_logits = torch.nn.functional.interpolate(
+            resized_logits = nn.functional.interpolate(
                 segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
             )
             semantic_map = resized_logits[0].argmax(dim=0)
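For context, the renamed semantic method slots into the usual pipeline like this (a sketch: `facebook/detr-resnet-50-panoptic` is one checkpoint that works here, and the method is assumed to return one class-id map per image when `target_sizes` is given, as the hunk above suggests):

```python
import requests
import torch
from PIL import Image
from transformers import DetrFeatureExtractor, DetrForSegmentation

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50-panoptic")
model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")

inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Resize each prediction back to the original (height, width) of its image
results = feature_extractor.post_process_semantic_segmentation(
    outputs, target_sizes=[(image.height, image.width)]
)
semantic_map = results[0]  # a (height, width) tensor of semantic class ids
```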
@@ -1114,31 +1211,34 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
         outputs,
         threshold: float = 0.5,
         overlap_mask_area_threshold: float = 0.8,
-        target_sizes: List[Tuple] = None,
+        target_sizes: Optional[List[Tuple[int, int]]] = None,
         return_coco_annotation: Optional[bool] = False,
-    ):
+    ) -> List[Dict]:
         """
         Converts the output of [`DetrForSegmentation`] into instance segmentation predictions. Only supports PyTorch.

         Args:
             outputs ([`DetrForSegmentation`]):
                 Raw outputs of the model.
-            threshold (`float`, *optional*):
-                The probability score threshold to keep predicted instance masks, defaults to 0.5.
-            overlap_mask_area_threshold (`float`, *optional*):
+            threshold (`float`, *optional*, defaults to 0.5):
+                The probability score threshold to keep predicted instance masks.
+            overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
                 The overlap mask area threshold to merge or discard small disconnected parts within each binary
-                instance mask, defaults to 0.8.
-            target_sizes (`List[Tuple]`, *optional*, defaults to `None`):
+                instance mask.
+            target_sizes (`List[Tuple]`, *optional*):
                 List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested
                 final size (height, width) of each prediction. If left to None, predictions will not be resized.
-            return_coco_annotation (`bool`, *optional*, defaults to `False`):
-                If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE) format. Defaults to `False`.
+            return_coco_annotation (`bool`, *optional*):
+                If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE) format.

         Returns:
             `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
             - **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id`, or
-              `List[List]` run-length encoding (RLE) of the segmentation map if return_coco_format is set to `True`.
-            - **segment_ids** -- A dictionary that maps segment ids to semantic class ids.
+              a `List[List]` run-length encoding (RLE) of the segmentation map if `return_coco_annotation` is set
+              to `True`. Set to `None` if no mask is found above `threshold`.
+            - **segments_info** -- A dictionary that contains additional information on each segment.
               - **id** -- An integer representing the `segment_id`.
-            - **label_id** -- An integer representing the segment's label / semantic class id.
+              - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
+              - **score** -- Prediction score of the segment with `segment_id`.
         """
         class_queries_logits = outputs.logits  # [batch_size, num_queries, num_classes+1]
         masks_queries_logits = outputs.pred_masks  # [batch_size, num_queries, height, width]
@@ -1159,76 +1259,27 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
                 mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
             )

-            height, width = target_sizes[i][0], target_sizes[i][1]
-            segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs_item.device)
-            segments: List[Dict] = []
-
-            object_detected = mask_probs_item.shape[0] > 0
-
-            if object_detected:
-                # Resize mask to corresponding target_size
-                if target_sizes is not None:
-                    mask_probs_item = torch.nn.functional.interpolate(
-                        mask_probs_item.unsqueeze(0),
-                        size=target_sizes[i],
-                        mode="bilinear",
-                        align_corners=False,
-                    )[0]
-
-                current_segment_id = 0
-                # Weigh each mask by its prediction score
-                mask_probs_item *= pred_scores_item.view(-1, 1, 1)
-                mask_labels_item = mask_probs_item.argmax(0)  # [height, width]
-
-                # Keep track of instances of each class
-                stuff_memory_list: Dict[str, int] = {}
-                for k in range(pred_labels_item.shape[0]):
-                    # Get the mask associated with the k class
-                    pred_class = pred_labels_item[k].item()
-                    mask_k = mask_labels_item == k
-                    mask_k_area = mask_k.sum()
-
-                    # Compute the area of all the stuff in query k
-                    original_area = (mask_probs_item[k] >= 0.5).sum()
-                    mask_exists = mask_k_area > 0 and original_area > 0
-
-                    if mask_exists:
-                        # Eliminate segments with mask area below threshold
-                        area_ratio = mask_k_area / original_area
-                        if not area_ratio.item() > overlap_mask_area_threshold:
-                            continue
-
-                        # Add corresponding class id
-                        if pred_class in stuff_memory_list:
-                            current_segment_id = stuff_memory_list[pred_class]
-                        else:
-                            current_segment_id += 1
-
-                        # Add current object segment to final segmentation map
-                        segmentation[mask_k] = current_segment_id
-                        segments.append(
-                            {
-                                "id": current_segment_id,
-                                "label_id": pred_class,
-                            }
-                        )
-            else:
-                segmentation -= 1
+            # No mask found above threshold for this image
+            if mask_probs_item.shape[0] <= 0:
+                results.append({"segmentation": None, "segments_info": []})
+                continue
+
+            # Get segmentation map and segment information of batch item
+            target_size = target_sizes[i] if target_sizes is not None else None
+            segmentation, segments = compute_segments(
+                mask_probs_item,
+                pred_scores_item,
+                pred_labels_item,
+                overlap_mask_area_threshold=overlap_mask_area_threshold,
+                label_ids_to_fuse=set(),  # instance segmentation never fuses instances
+                target_size=target_size,
+            )

             # Return segmentation map in run-length encoding (RLE) format
             if return_coco_annotation:
-                segment_ids = torch.unique(segmentation)
-
-                run_length_encodings = []
-                for idx in segment_ids:
-                    mask = torch.where(segmentation == idx, 1, 0)
-                    rle = binary_mask_to_rle(mask)
-                    run_length_encodings.append(rle)
-                segmentation = run_length_encodings
+                segmentation = convert_segmentation_to_rle(segmentation)

-            results.append({"segmentation": segmentation, "segment_ids": segments})
+            results.append({"segmentation": segmentation, "segments_info": segments})
         return results

     def post_process_panoptic_segmentation(
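End to end, the restructured instance method can be exercised like so (a sketch reusing `feature_extractor`, `outputs` and `image` from the semantic example earlier):

```python
results = feature_extractor.post_process_instance_segmentation(
    outputs,
    threshold=0.5,
    target_sizes=[(image.height, image.width)],
)
prediction = results[0]
instance_map = prediction["segmentation"]    # (height, width) tensor of segment ids, or None
for segment in prediction["segments_info"]:  # per-segment scores are new in this commit
    print(segment["id"], segment["label_id"], segment["score"])

# With return_coco_annotation=True, "segmentation" is instead a list of RLE lists,
# one per segment id, produced by convert_segmentation_to_rle
```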
@@ -1237,7 +1288,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
         threshold: float = 0.5,
         overlap_mask_area_threshold: float = 0.8,
         label_ids_to_fuse: Optional[Set[int]] = None,
-        target_sizes: List[Tuple] = None,
+        target_sizes: Optional[List[Tuple[int, int]]] = None,
     ) -> List[Dict]:
         """
         Args:
@@ -1250,7 +1301,7 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
             overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
                 The overlap mask area threshold to merge or discard small disconnected parts within each binary
                 instance mask.
-            label_ids_to_fuse (`Set[int]`, *optional*, defaults to `None`):
+            label_ids_to_fuse (`Set[int]`, *optional*):
                 The labels in this set will have all their instances fused together. For instance, we could say
                 there can only be one sky in an image, but several persons, so the label ID for sky would be in that
                 set, but not the one for person.
@@ -1260,13 +1311,15 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
                 resized.

         Returns:
             `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
-            - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`. If
-              `target_sizes` is specified, segmentation is resized to the corresponding `target_sizes` entry.
-            - **segment_ids** -- A dictionary that maps segment ids to semantic class ids.
-            - **id** -- An integer representing the `segment_id`.
-            - **label_id** -- An integer representing the segment's label / semantic class id.
+            - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`, or
+              `None` if no mask is found above `threshold`. If `target_sizes` is specified, segmentation is resized
+              to the corresponding `target_sizes` entry.
+            - **segments_info** -- A dictionary that contains additional information on each segment.
+              - **id** -- An integer representing the `segment_id`.
+              - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
               - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
                 Multiple instances of the same class / label were fused and assigned a single `segment_id`.
+              - **score** -- Prediction score of the segment with `segment_id`.
         """

         if label_ids_to_fuse is None:
@@ -1292,67 +1345,22 @@ class DetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
                 mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
             )

-            height, width = target_sizes[i][0], target_sizes[i][1]
-            segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs_item.device)
-            segments: List[Dict] = []
-
-            object_detected = mask_probs_item.shape[0] > 0
-
-            if object_detected:
-                # Resize mask to corresponding target_size
-                if target_sizes is not None:
-                    mask_probs_item = torch.nn.functional.interpolate(
-                        mask_probs_item.unsqueeze(0),
-                        size=target_sizes[i],
-                        mode="bilinear",
-                        align_corners=False,
-                    )[0]
-
-                current_segment_id = 0
-                # Weigh each mask by its prediction score
-                mask_probs_item *= pred_scores_item.view(-1, 1, 1)
-                mask_labels_item = mask_probs_item.argmax(0)  # [height, width]
-
-                # Keep track of instances of each class
-                stuff_memory_list: Dict[str, int] = {}
-                for k in range(pred_labels_item.shape[0]):
-                    pred_class = pred_labels_item[k].item()
-                    should_fuse = pred_class in label_ids_to_fuse
-
-                    # Get the mask associated with the k class
-                    mask_k = mask_labels_item == k
-                    mask_k_area = mask_k.sum()
-
-                    # Compute the area of all the stuff in query k
-                    original_area = (mask_probs_item[k] >= 0.5).sum()
-                    mask_exists = mask_k_area > 0 and original_area > 0
-
-                    if mask_exists:
-                        # Eliminate disconnected tiny segments
-                        area_ratio = mask_k_area / original_area
-                        if not area_ratio.item() > overlap_mask_area_threshold:
-                            continue
-
-                        # Add corresponding class id
-                        if pred_class in stuff_memory_list:
-                            current_segment_id = stuff_memory_list[pred_class]
-                        else:
-                            current_segment_id += 1
-
-                        # Add current object segment to final segmentation map
-                        segmentation[mask_k] = current_segment_id
-                        segments.append(
-                            {
-                                "id": current_segment_id,
-                                "label_id": pred_class,
-                                "was_fused": should_fuse,
-                            }
-                        )
-                        if should_fuse:
-                            stuff_memory_list[pred_class] = current_segment_id
-            else:
-                segmentation -= 1
-
-            results.append({"segmentation": segmentation, "segment_ids": segments})
+            # No mask found above threshold for this image
+            if mask_probs_item.shape[0] <= 0:
+                results.append({"segmentation": None, "segments_info": []})
+                continue
+
+            # Get segmentation map and segment information of batch item
+            target_size = target_sizes[i] if target_sizes is not None else None
+            segmentation, segments = compute_segments(
+                mask_probs_item,
+                pred_scores_item,
+                pred_labels_item,
+                overlap_mask_area_threshold,
+                label_ids_to_fuse,
+                target_size,
+            )
+
+            results.append({"segmentation": segmentation, "segments_info": segments})
         return results
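And the panoptic variant, which additionally accepts `label_ids_to_fuse` (again a sketch reusing the earlier setup; the fused label ids below are illustrative placeholders, not verified COCO panoptic "stuff" ids):

```python
results = feature_extractor.post_process_panoptic_segmentation(
    outputs,
    threshold=0.5,
    overlap_mask_area_threshold=0.8,
    label_ids_to_fuse={186, 187},  # hypothetical ids for "stuff" classes like sky / water
    target_sizes=[(image.height, image.width)],
)
panoptic_map = results[0]["segmentation"]  # (height, width) tensor of segment ids, or None
for segment in results[0]["segments_info"]:
    print(segment["id"], segment["label_id"], segment["was_fused"], segment["score"])
```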
@@ -1605,12 +1605,12 @@ class DetrForSegmentation(DetrPreTrainedModel):

         >>> # Use the `post_process_panoptic_segmentation` method of `DetrFeatureExtractor` to retrieve post-processed panoptic segmentation maps
         >>> # Segmentation results are returned as a list of dictionaries
-        >>> result = feature_extractor.post_process_panoptic_segmentation(outputs, processed_sizes)
+        >>> result = feature_extractor.post_process_panoptic_segmentation(outputs, target_sizes=[(300, 500)])

         >>> # A tensor of shape (height, width) where each value denotes a segment id
         >>> panoptic_seg = result[0]["segmentation"]

-        >>> # Get mapping of segment ids to semantic class ids
-        >>> panoptic_segments_info = result[0]["segment_ids"]
+        >>> # Get additional information on each segment (id, label_id, was_fused, score)
+        >>> panoptic_segments_info = result[0]["segments_info"]
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict