Unverified commit d5606378, authored by Harisankar Babu and committed by GitHub

Optionally preprocess segmentation maps for MobileViT (#28420)

* optionally preprocess segmentation maps for mobilevit

* changed pretrained model name to that of segmentation model

* removed voc-deeplabv3 from model archive list

* added preprocess_image and preprocess_mask methods for processing images and segmentation masks respectively

* added tests for segmentation masks based on segformer feature extractor

* use crop_size instead of size

* reverting to initial model
parent 95091e15
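The commit above makes segmentation maps an optional second input to the MobileViT image processor. A minimal usage sketch of the resulting API follows; the checkpoint name apple/deeplabv3-mobilevit-small and the file names are illustrative assumptions, not part of the diff.

# Illustrative sketch of the API added by this commit; the checkpoint and file
# names are placeholders chosen for the example.
from PIL import Image

from transformers import MobileViTImageProcessor

image_processor = MobileViTImageProcessor.from_pretrained("apple/deeplabv3-mobilevit-small")

image = Image.open("scene.jpg")                   # RGB input image
segmentation_map = Image.open("scene_mask.png")   # per-pixel class ids

# Segmentation maps can be passed positionally or via `segmentation_maps=`.
encoding = image_processor(image, segmentation_maps=segmentation_map, return_tensors="pt")

print(encoding["pixel_values"].shape)  # (1, 3, crop_height, crop_width)
print(encoding["labels"].shape)        # (1, crop_height, crop_width), dtype torch.int64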
@@ -19,12 +19,7 @@ from typing import Dict, List, Optional, Tuple, Union

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import flip_channel_order, get_resize_output_image_size, resize, to_channel_dimension_format
from ...image_utils import (
    ChannelDimension,
    ImageInput,
@@ -178,9 +173,126 @@ class MobileViTImageProcessor(BaseImageProcessor):
        """
        return flip_channel_order(image, data_format=data_format, input_data_format=input_data_format)

def __call__(self, images, segmentation_maps=None, **kwargs):
"""
Preprocesses a batch of images and optionally segmentation maps.
Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
passed in as positional arguments.
"""
return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
def _preprocess(
self,
image: ImageInput,
do_resize: bool,
do_rescale: bool,
do_center_crop: bool,
do_flip_channel_order: bool,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = None,
rescale_factor: Optional[float] = None,
crop_size: Optional[Dict[str, int]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
if do_resize:
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
if do_rescale:
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
if do_center_crop:
image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
if do_flip_channel_order:
image = self.flip_channel_order(image, input_data_format=input_data_format)
return image
def _preprocess_image(
self,
image: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
do_flip_channel_order: bool = None,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""Preprocesses a single image."""
# All transformations expect numpy arrays.
image = to_numpy_array(image)
if is_scaled_image(image) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
image = self._preprocess(
image=image,
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_flip_channel_order=do_flip_channel_order,
input_data_format=input_data_format,
)
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
return image
def _preprocess_mask(
self,
segmentation_map: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""Preprocesses a single mask."""
segmentation_map = to_numpy_array(segmentation_map)
# Add channel dimension if missing - needed for certain transformations
if segmentation_map.ndim == 2:
added_channel_dim = True
segmentation_map = segmentation_map[None, ...]
input_data_format = ChannelDimension.FIRST
else:
added_channel_dim = False
if input_data_format is None:
input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
segmentation_map = self._preprocess(
image=segmentation_map,
do_resize=do_resize,
size=size,
resample=PILImageResampling.NEAREST,
do_rescale=False,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_flip_channel_order=False,
input_data_format=input_data_format,
)
# Remove extra channel dimension if added for processing
if added_channel_dim:
segmentation_map = segmentation_map.squeeze(0)
segmentation_map = segmentation_map.astype(np.int64)
return segmentation_map

    def preprocess(
        self,
        images: ImageInput,
        segmentation_maps: Optional[ImageInput] = None,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
@@ -201,6 +313,8 @@ class MobileViTImageProcessor(BaseImageProcessor):
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            segmentation_maps (`ImageInput`, *optional*):
                Segmentation map to preprocess.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
@@ -251,6 +365,8 @@ class MobileViTImageProcessor(BaseImageProcessor):
        crop_size = get_size_dict(crop_size, param_name="crop_size")

        images = make_list_of_images(images)

        if segmentation_maps is not None:
            segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)

        if not valid_images(images):
            raise ValueError(
@@ -258,6 +374,12 @@ class MobileViTImageProcessor(BaseImageProcessor):
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if segmentation_maps is not None and not valid_images(segmentation_maps):
            raise ValueError(
                "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_resize and size is None:
            raise ValueError("Size must be specified if do_resize is True.")
@@ -267,45 +389,40 @@ class MobileViTImageProcessor(BaseImageProcessor):
        if do_center_crop and crop_size is None:
            raise ValueError("Crop size must be specified if do_center_crop is True.")

-        # All transformations expect numpy arrays.
-        images = [to_numpy_array(image) for image in images]
-
-        if is_scaled_image(images[0]) and do_rescale:
-            logger.warning_once(
-                "It looks like you are trying to rescale already rescaled images. If the input"
-                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
-            )
-
-        if input_data_format is None:
-            # We assume that all images have the same channel dimension format.
-            input_data_format = infer_channel_dimension_format(images[0])
-
-        if do_resize:
-            images = [
-                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
-                for image in images
-            ]
-
-        if do_center_crop:
-            images = [
-                self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
-            ]
-
-        if do_rescale:
-            images = [
-                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
-                for image in images
-            ]
-
-        # the pretrained checkpoints assume images are BGR, not RGB
-        if do_flip_channel_order:
-            images = [self.flip_channel_order(image=image, input_data_format=input_data_format) for image in images]
-
-        images = [
-            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
-        ]
-
-        data = {"pixel_values": images}
+        images = [
+            self._preprocess_image(
+                image=img,
+                do_resize=do_resize,
+                size=size,
+                resample=resample,
+                do_rescale=do_rescale,
+                rescale_factor=rescale_factor,
+                do_center_crop=do_center_crop,
+                crop_size=crop_size,
+                do_flip_channel_order=do_flip_channel_order,
+                data_format=data_format,
+                input_data_format=input_data_format,
+            )
+            for img in images
+        ]
+
+        data = {"pixel_values": images}
+
+        if segmentation_maps is not None:
+            segmentation_maps = [
+                self._preprocess_mask(
+                    segmentation_map=segmentation_map,
+                    do_resize=do_resize,
+                    size=size,
+                    do_center_crop=do_center_crop,
+                    crop_size=crop_size,
+                    input_data_format=input_data_format,
+                )
+                for segmentation_map in segmentation_maps
+            ]
+
+            data["labels"] = segmentation_maps

        return BatchFeature(data=data, tensor_type=return_tensors)

    # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->MobileViT
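As a quick reference for the preprocessing behaviour added above, and before the test-file diff that follows, here is a small sketch of what preprocess returns when a 2D integer mask is passed alongside an image: the mask only goes through the geometric transforms, is resampled with nearest neighbour, and comes back as an int64 "labels" array, mirroring _preprocess_mask. The array sizes and the size/crop_size values below are arbitrary choices for illustration.

# Illustrative sketch only: random data standing in for a real image and mask,
# exercising the new segmentation-map path of MobileViTImageProcessor.preprocess.
import numpy as np

from transformers import MobileViTImageProcessor

image_processor = MobileViTImageProcessor(
    do_resize=True,
    size={"shortest_edge": 288},
    do_center_crop=True,
    crop_size={"height": 256, "width": 256},
)

image = np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8)  # HWC image, values 0-255
mask = np.random.randint(0, 21, size=(300, 400), dtype=np.uint8)       # 2D map of class ids

batch = image_processor.preprocess(image, segmentation_maps=mask, return_tensors="np")

# Images are resized, center-cropped, rescaled and channel-flipped; masks only
# get the geometric transforms, with nearest-neighbour resampling and no rescaling.
print(batch["pixel_values"].shape)  # (1, 3, 256, 256)
print(batch["labels"].shape)        # (1, 256, 256)
print(batch["labels"].dtype)        # int64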
@@ -16,13 +16,20 @@

import unittest

from datasets import load_dataset

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs


if is_torch_available():
    import torch

if is_vision_available():
    from PIL import Image

    from transformers import MobileViTImageProcessor
@@ -79,6 +86,26 @@ class MobileViTImageProcessingTester(unittest.TestCase):
        )

def prepare_semantic_single_inputs():
    dataset = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")

    image = Image.open(dataset[0]["file"])
    map = Image.open(dataset[1]["file"])

    return image, map


def prepare_semantic_batch_inputs():
    dataset = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")

    image1 = Image.open(dataset[0]["file"])
    map1 = Image.open(dataset[1]["file"])
    image2 = Image.open(dataset[2]["file"])
    map2 = Image.open(dataset[3]["file"])

    return [image1, image2], [map1, map2]

@require_torch
@require_vision
class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
@@ -107,3 +134,109 @@ class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
        self.assertEqual(image_processor.size, {"shortest_edge": 42})
        self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})

    def test_call_segmentation_maps(self):
        # Initialize image_processing
        image_processing = self.image_processing_class(**self.image_processor_dict)
        # create random PyTorch tensors
        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
        maps = []
        for image in image_inputs:
            self.assertIsInstance(image, torch.Tensor)
            maps.append(torch.zeros(image.shape[-2:]).long())

        # Test not batched input
        encoding = image_processing(image_inputs[0], maps[0], return_tensors="pt")
        self.assertEqual(
            encoding["pixel_values"].shape,
            (
                1,
                self.image_processor_tester.num_channels,
                self.image_processor_tester.crop_size["height"],
                self.image_processor_tester.crop_size["width"],
            ),
        )
        self.assertEqual(
            encoding["labels"].shape,
            (
                1,
                self.image_processor_tester.crop_size["height"],
                self.image_processor_tester.crop_size["width"],
            ),
        )
        self.assertEqual(encoding["labels"].dtype, torch.long)
        self.assertTrue(encoding["labels"].min().item() >= 0)
        self.assertTrue(encoding["labels"].max().item() <= 255)

        # Test batched
        encoding = image_processing(image_inputs, maps, return_tensors="pt")
        self.assertEqual(
            encoding["pixel_values"].shape,
            (
                self.image_processor_tester.batch_size,
                self.image_processor_tester.num_channels,
                self.image_processor_tester.crop_size["height"],
                self.image_processor_tester.crop_size["width"],
            ),
        )
        self.assertEqual(
            encoding["labels"].shape,
            (
                self.image_processor_tester.batch_size,
                self.image_processor_tester.crop_size["height"],
                self.image_processor_tester.crop_size["width"],
            ),
        )
        self.assertEqual(encoding["labels"].dtype, torch.long)
        self.assertTrue(encoding["labels"].min().item() >= 0)
        self.assertTrue(encoding["labels"].max().item() <= 255)

        # Test not batched input (PIL images)
        image, segmentation_map = prepare_semantic_single_inputs()
        encoding = image_processing(image, segmentation_map, return_tensors="pt")
        self.assertEqual(
            encoding["pixel_values"].shape,
            (
                1,
                self.image_processor_tester.num_channels,
                self.image_processor_tester.crop_size["height"],
                self.image_processor_tester.crop_size["width"],
            ),
        )
        self.assertEqual(
            encoding["labels"].shape,
            (
                1,
                self.image_processor_tester.crop_size["height"],
                self.image_processor_tester.crop_size["width"],
            ),
        )
        self.assertEqual(encoding["labels"].dtype, torch.long)
        self.assertTrue(encoding["labels"].min().item() >= 0)
        self.assertTrue(encoding["labels"].max().item() <= 255)

        # Test batched input (PIL images)
        images, segmentation_maps = prepare_semantic_batch_inputs()
        encoding = image_processing(images, segmentation_maps, return_tensors="pt")
        self.assertEqual(
            encoding["pixel_values"].shape,
            (
                2,
                self.image_processor_tester.num_channels,
                self.image_processor_tester.crop_size["height"],
                self.image_processor_tester.crop_size["width"],
            ),
        )
        self.assertEqual(
            encoding["labels"].shape,
            (
                2,
                self.image_processor_tester.crop_size["height"],
                self.image_processor_tester.crop_size["width"],
            ),
        )
        self.assertEqual(encoding["labels"].dtype, torch.long)
        self.assertTrue(encoding["labels"].min().item() >= 0)
        self.assertTrue(encoding["labels"].max().item() <= 255)
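Beyond the tests above, the "labels" produced by the image processor line up with what MobileViTForSemanticSegmentation expects, and the post_process_semantic_segmentation method referenced in the image-processor diff converts logits back to per-pixel class ids. The sketch below is illustrative only; the checkpoint and file names are assumed examples and are not part of this commit.

# End-to-end sketch (not part of the diff): feed pixel_values and labels to the
# segmentation model for the loss, then map logits back to per-pixel class ids.
# The checkpoint name and file paths are assumed examples.
import torch
from PIL import Image

from transformers import MobileViTForSemanticSegmentation, MobileViTImageProcessor

checkpoint = "apple/deeplabv3-mobilevit-small"
image_processor = MobileViTImageProcessor.from_pretrained(checkpoint)
model = MobileViTForSemanticSegmentation.from_pretrained(checkpoint)

image = Image.open("scene.jpg")                   # placeholder paths
segmentation_map = Image.open("scene_mask.png")

inputs = image_processor(image, segmentation_maps=segmentation_map, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)  # pixel_values and labels -> logits and loss

print(outputs.loss)

# Resize the logits to the original image size and take the argmax per pixel.
predicted = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
print(predicted.shape)  # (height, width) of the original image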