Unverified Commit c81ebd1c authored by Alara Dirik's avatar Alara Dirik Committed by GitHub
Browse files

Beit postprocessing (#19099)

* Add `post_process_semantic_segmentation` method to `BeitFeatureExtractor`
parent 261301d3
...@@ -82,6 +82,7 @@ contributed by [kamalkraj](https://huggingface.co/kamalkraj). The original code ...@@ -82,6 +82,7 @@ contributed by [kamalkraj](https://huggingface.co/kamalkraj). The original code
[[autodoc]] BeitFeatureExtractor [[autodoc]] BeitFeatureExtractor
- __call__ - __call__
- post_process_semantic_segmentation
## BeitModel ## BeitModel
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
"""Feature extractor class for BEiT.""" """Feature extractor class for BEiT."""
from typing import Optional, Union from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
from PIL import Image from PIL import Image
...@@ -27,9 +27,12 @@ from ...image_utils import ( ...@@ -27,9 +27,12 @@ from ...image_utils import (
ImageInput, ImageInput,
is_torch_tensor, is_torch_tensor,
) )
from ...utils import TensorType, logging from ...utils import TensorType, is_torch_available, logging
if is_torch_available():
import torch
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -222,3 +225,44 @@ class BeitFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): ...@@ -222,3 +225,44 @@ class BeitFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs return encoded_inputs
def post_process_semantic_segmentation(
    self, outputs, target_sizes: Optional[Union["torch.Tensor", List[Tuple[int, int]]]] = None
):
    """
    Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.

    Args:
        outputs ([`BeitForSemanticSegmentation`]):
            Raw outputs of the model. Only `outputs.logits` is read; the class scores are expected along
            dimension 1.
        target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*):
            Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. If left to
            None, predictions will not be resized.

    Returns:
        semantic_segmentation: If `target_sizes` is `None`, a single `torch.Tensor` holding the argmax of the
        logits over dimension 1 (presumably of shape `(batch_size, height, width)`). Otherwise a
        `List[torch.Tensor]` of length `batch_size`, where each item is the semantic segmentation map resized to
        the corresponding `target_sizes` entry. Each entry of each `torch.Tensor` corresponds to a semantic class
        id.

    Raises:
        ValueError: If the number of target sizes differs from the batch dimension of the logits, or if an entry
            of `target_sizes` does not hold exactly two values.
    """
    logits = outputs.logits
    semantic_segmentation = logits.argmax(dim=1)

    # Nothing to validate or resize when no target sizes are requested.
    if target_sizes is None:
        return semantic_segmentation

    if is_torch_tensor(target_sizes):
        # Move to host memory first: `Tensor.numpy()` is not available for tensors on an accelerator.
        target_sizes = target_sizes.cpu().numpy()

    if len(logits) != len(target_sizes):
        raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")

    # `np.shape` handles tuples, lists and array rows alike, so both documented input types are validated.
    if any(np.shape(size) != (2,) for size in target_sizes):
        raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")

    # Resize semantic segmentation maps
    # NOTE(review): `self.resize` is defined outside this block; confirm that it expects sizes in (h, w) order and
    # that its default resampling is appropriate for maps of class ids.
    class_maps = semantic_segmentation.cpu().numpy()
    resized_maps = [self.resize(image=class_map, size=size) for class_map, size in zip(class_maps, target_sizes)]

    return [torch.Tensor(np.array(image)) for image in resized_maps]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment