Unverified Commit 9e957066 authored by Alara Dirik's avatar Alara Dirik Committed by GitHub
Browse files

Add post_process_semantic_segmentation method to SegFormer (#19072)

* add post_process_semantic_segmentation method to SegformerFeatureExtractor
* add test for semantic segmentation post-processing
parent ef6741fe
...@@ -93,6 +93,7 @@ SegFormer's results on the segmentation datasets like ADE20k, refer to the [pape ...@@ -93,6 +93,7 @@ SegFormer's results on the segmentation datasets like ADE20k, refer to the [pape
[[autodoc]] SegformerFeatureExtractor [[autodoc]] SegformerFeatureExtractor
- __call__ - __call__
- post_process_semantic_segmentation
## SegformerModel ## SegformerModel
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
"""Feature extractor class for SegFormer.""" """Feature extractor class for SegFormer."""
from typing import Optional, Union from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
from PIL import Image from PIL import Image
...@@ -27,9 +27,12 @@ from ...image_utils import ( ...@@ -27,9 +27,12 @@ from ...image_utils import (
ImageInput, ImageInput,
is_torch_tensor, is_torch_tensor,
) )
from ...utils import TensorType, logging from ...utils import TensorType, is_torch_available, logging
if is_torch_available():
import torch
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -211,3 +214,45 @@ class SegformerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMi ...@@ -211,3 +214,45 @@ class SegformerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMi
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs return encoded_inputs
def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
"""
Converts the output of [`SegformerForSemanticSegmentation`] into semantic segmentation maps. Only supports
PyTorch.
Args:
outputs ([`SegformerForSemanticSegmentation`]):
Raw outputs of the model.
target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
List of tuples corresponding to the requested final size (height, width) of each prediction. If left to
None, predictions will not be resized.
Returns:
semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
"""
logits = outputs.logits
# Resize logits and compute semantic segmentation maps
if target_sizes is not None:
if len(logits) != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
if is_torch_tensor(target_sizes):
target_sizes = target_sizes.numpy()
semantic_segmentation = []
for idx in range(len(logits)):
resized_logits = torch.nn.functional.interpolate(
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
)
semantic_map = resized_logits[0].argmax(dim=0)
semantic_segmentation.append(semantic_map)
else:
semantic_segmentation = logits.argmax(dim=1)
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
return semantic_segmentation
...@@ -395,3 +395,30 @@ class SegformerModelIntegrationTest(unittest.TestCase): ...@@ -395,3 +395,30 @@ class SegformerModelIntegrationTest(unittest.TestCase):
] ]
).to(torch_device) ).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-1)) self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-1))
@slow
def test_post_processing_semantic_segmentation(self):
# only resize + normalize
feature_extractor = SegformerFeatureExtractor(
image_scale=(512, 512), keep_ratio=False, align=False, do_random_crop=False
)
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512").to(
torch_device
)
image = prepare_img()
encoded_inputs = feature_extractor(images=image, return_tensors="pt")
pixel_values = encoded_inputs.pixel_values.to(torch_device)
with torch.no_grad():
outputs = model(pixel_values)
outputs.logits = outputs.logits.detach().cpu()
segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(500, 300)])
expected_shape = torch.Size((500, 300))
self.assertEqual(segmentation[0].shape, expected_shape)
segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs)
expected_shape = torch.Size((128, 128))
self.assertEqual(segmentation[0].shape, expected_shape)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment