"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ac52084bf29ab02ee4cbc79d7330562b5df17df2"
Unverified commit e7fdfc72, authored by Alara Dirik and committed by GitHub
Browse files

Add post_process_semantic_segmentation method to DPTFeatureExtractor (#19107)

* add post-processing method for semantic segmentation

* add test for post-processing
parent da6a1b6c
...@@ -37,6 +37,7 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi ...@@ -37,6 +37,7 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi
[[autodoc]] DPTFeatureExtractor [[autodoc]] DPTFeatureExtractor
- __call__ - __call__
- post_process_semantic_segmentation
## DPTModel ## DPTModel
......
...@@ -14,13 +14,12 @@ ...@@ -14,13 +14,12 @@
# limitations under the License. # limitations under the License.
"""Feature extractor class for DPT.""" """Feature extractor class for DPT."""
from typing import Optional, Union from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...file_utils import TensorType
from ...image_utils import ( from ...image_utils import (
IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD, IMAGENET_STANDARD_STD,
...@@ -28,9 +27,12 @@ from ...image_utils import ( ...@@ -28,9 +27,12 @@ from ...image_utils import (
ImageInput, ImageInput,
is_torch_tensor, is_torch_tensor,
) )
from ...utils import logging from ...utils import TensorType, is_torch_available, logging
if is_torch_available():
import torch
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -200,3 +202,44 @@ class DPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): ...@@ -200,3 +202,44 @@ class DPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs return encoded_inputs
def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
    """
    Converts the output of [`DPTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.

    Args:
        outputs ([`DPTForSemanticSegmentation`]):
            Raw outputs of the model.
        target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
            List of tuples corresponding to the requested final size (height, width) of each prediction. If left to
            None, predictions will not be resized. A tensor or array of shape `(batch_size, 2)` is also accepted.

    Returns:
        semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
        segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
        specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.

    Raises:
        ValueError: If `target_sizes` is given and its length differs from the batch dimension of the logits.
    """
    logits = outputs.logits

    if target_sizes is None:
        # No resizing requested: arg-max over the class dimension, then split the batch into one map per image.
        return list(logits.argmax(dim=1))

    if len(logits) != len(target_sizes):
        raise ValueError(
            "Make sure that you pass in as many target sizes as the batch dimension of the logits"
        )

    # Tensors and arrays are turned into plain Python lists. `.tolist()` works regardless of the device the
    # tensor lives on, whereas `.numpy()` raises for CUDA tensors.
    if hasattr(target_sizes, "tolist"):
        target_sizes = target_sizes.tolist()

    semantic_segmentation = []
    for image_logits, target_size in zip(logits, target_sizes):
        # `interpolate` expects a batched input and a size made of plain integers.
        size = tuple(int(dim) for dim in target_size)
        resized_logits = torch.nn.functional.interpolate(
            image_logits.unsqueeze(dim=0), size=size, mode="bilinear", align_corners=False
        )
        semantic_segmentation.append(resized_logits[0].argmax(dim=0))

    return semantic_segmentation
...@@ -298,3 +298,24 @@ class DPTModelIntegrationTest(unittest.TestCase): ...@@ -298,3 +298,24 @@ class DPTModelIntegrationTest(unittest.TestCase):
).to(torch_device) ).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, 0, :3, :3], expected_slice, atol=1e-4)) self.assertTrue(torch.allclose(outputs.logits[0, 0, :3, :3], expected_slice, atol=1e-4))
def test_post_processing_semantic_segmentation(self):
    """Check the shapes of the maps returned by `post_process_semantic_segmentation`, with and without resizing."""
    extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large-ade")
    segmenter = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade")
    segmenter = segmenter.to(torch_device)

    encoded = extractor(images=prepare_img(), return_tensors="pt").to(torch_device)

    # Forward pass without gradient tracking.
    with torch.no_grad():
        outputs = segmenter(**encoded)

    # Post-processing is exercised on CPU logits.
    outputs.logits = outputs.logits.detach().cpu()

    # With explicit target sizes, the map takes the requested (height, width).
    resized_maps = extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(500, 300)])
    self.assertEqual(resized_maps[0].shape, torch.Size((500, 300)))

    # Without target sizes, the map keeps the spatial size of the logits.
    native_maps = extractor.post_process_semantic_segmentation(outputs=outputs)
    self.assertEqual(native_maps[0].shape, torch.Size((480, 480)))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment