"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "79d28e80b6f5a37c3f6bacf3fd708963c58b68fb"
Unverified commit 6d4cabda, authored by Eduardo Pacheco, committed by GitHub

[SegGPT] Fix seggpt image processor (#29550)

* Fixed SegGptImageProcessor to handle 2D and 3D prompt mask inputs

* Added new test to check prompt mask equivalence

* New proposal

* Better proposal

* Removed unnecessary method

* Updated seggpt docs

* Introduced do_convert_rgb

* nits
parent c793b26f
@@ -26,7 +26,8 @@ The abstract from the paper is the following:
 
 Tips:
 - One can use [`SegGptImageProcessor`] to prepare image input, prompt and mask to the model.
-- It's highly advisable to pass `num_labels` (not considering background) during preprocessing and postprocessing with [`SegGptImageProcessor`] for your use case.
+- One can either use segmentation maps or RGB images as prompt masks. If using the latter, make sure to set `do_convert_rgb=False` in the `preprocess` method.
+- It's highly advisable to pass `num_labels` when using `segmentation_maps` (not considering background) during preprocessing and postprocessing with [`SegGptImageProcessor`] for your use case.
 - When doing inference with [`SegGptForImageSegmentation`] if your `batch_size` is greater than 1 you can use feature ensemble across your images by passing `feature_ensemble=True` in the forward method.
 
 Here's how to use the model for one-shot semantic segmentation:
@@ -53,7 +54,7 @@ mask_prompt = ds[29]["label"]
 inputs = image_processor(
     images=image_input,
     prompt_images=image_prompt,
-    prompt_masks=mask_prompt,
+    segmentation_maps=mask_prompt,
     num_labels=num_labels,
     return_tensors="pt"
 )
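The updated snippet covers the segmentation-map path. For the RGB prompt-mask path that this commit introduces, a minimal companion sketch (illustrative only; the checkpoint name is taken from the integration tests below, and the random arrays are toy stand-ins for real images):

import numpy as np
from transformers import SegGptImageProcessor

image_processor = SegGptImageProcessor.from_pretrained("BAAI/seggpt-vit-large")

# Toy stand-ins for real inputs; any HxWx3 uint8 arrays or PIL images work.
image_input = np.random.randint(0, 256, (448, 448, 3), dtype=np.uint8)
image_prompt = np.random.randint(0, 256, (448, 448, 3), dtype=np.uint8)
mask_prompt = np.zeros((448, 448, 3), dtype=np.uint8)  # prompt mask already in RGB

inputs = image_processor(
    images=image_input,
    prompt_images=image_prompt,
    prompt_masks=mask_prompt,
    return_tensors="pt",
    do_convert_rgb=False,  # the mask is already 3-channel, so skip conversion
)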
...
@@ -26,19 +26,21 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
-    get_channel_dimension_axis,
     infer_channel_dimension_format,
     is_scaled_image,
     make_list_of_images,
     to_numpy_array,
     valid_images,
 )
-from ...utils import TensorType, is_torch_available, logging, requires_backends
+from ...utils import TensorType, is_torch_available, is_vision_available, logging, requires_backends
 
 
 if is_torch_available():
     import torch
 
+if is_vision_available():
+    pass
+
 
 logger = logging.get_logger(__name__)
@@ -65,29 +67,10 @@ def build_palette(num_labels: int) -> List[Tuple[int, int]]:
     return color_list
 
 
-def get_num_channels(image: np.ndarray, input_data_format: ChannelDimension) -> int:
-    if image.ndim == 2:
-        return 0
-
-    channel_idx = get_channel_dimension_axis(image, input_data_format)
-    return image.shape[channel_idx]
-
-
 def mask_to_rgb(
-    mask: np.ndarray,
-    palette: Optional[List[Tuple[int, int]]] = None,
-    input_data_format: Optional[ChannelDimension] = None,
-    data_format: Optional[ChannelDimension] = None,
+    mask: np.ndarray, palette: Optional[List[Tuple[int, int]]] = None, data_format: Optional[ChannelDimension] = None
 ) -> np.ndarray:
-    if input_data_format is None and mask.ndim > 2:
-        input_data_format = infer_channel_dimension_format(mask)
-
-    data_format = data_format if data_format is not None else input_data_format
-
-    num_channels = get_num_channels(mask, input_data_format)
-
-    if num_channels == 3:
-        return to_channel_dimension_format(mask, data_format, input_data_format) if data_format is not None else mask
+    data_format = data_format if data_format is not None else ChannelDimension.FIRST
 
     if palette is not None:
         height, width = mask.shape
@@ -109,9 +92,7 @@ def mask_to_rgb(
     else:
         rgb_mask = np.repeat(mask[None, ...], 3, axis=0)
 
-    return (
-        to_channel_dimension_format(rgb_mask, data_format, input_data_format) if data_format is not None else rgb_mask
-    )
+    return to_channel_dimension_format(rgb_mask, data_format)
 
 
 class SegGptImageProcessor(BaseImageProcessor):
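The rework above removes the RGB pass-through from `mask_to_rgb` and pins the default output to channels-first. A rough sketch of the two remaining paths (illustrative only, not the library code; the three-color palette is made up):

import numpy as np

mask = np.array([[0, 1], [1, 2]])  # toy 2x2 segmentation map of class indices
palette = [(0, 0, 0), (255, 0, 0), (0, 255, 0)]  # one RGB color per class index

# Palette path: look up each class index, then move channels first -> (3, 2, 2)
rgb_from_palette = np.array(palette)[mask].transpose(2, 0, 1)

# No-palette path: duplicate the single channel three times -> (3, 2, 2)
rgb_duplicated = np.repeat(mask[None, ...], 3, axis=0)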
@@ -143,6 +124,9 @@ class SegGptImageProcessor(BaseImageProcessor):
         image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
             Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
             number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_convert_rgb (`bool`, *optional*, defaults to `True`):
+            Whether to convert the prompt mask to RGB format. Can be overridden by the `do_convert_rgb` parameter in
+            the `preprocess` method.
     """
 
     model_input_names = ["pixel_values"]
@@ -157,6 +141,7 @@ class SegGptImageProcessor(BaseImageProcessor):
         do_normalize: bool = True,
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
+        do_convert_rgb: bool = True,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -170,6 +155,7 @@ class SegGptImageProcessor(BaseImageProcessor):
         self.rescale_factor = rescale_factor
         self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
         self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
+        self.do_convert_rgb = do_convert_rgb
 
     def get_palette(self, num_labels: int) -> List[Tuple[int, int]]:
         """Build a palette to map the prompt mask from a single channel to a 3 channel RGB.
@@ -188,13 +174,12 @@ class SegGptImageProcessor(BaseImageProcessor):
         image: np.ndarray,
         palette: Optional[List[Tuple[int, int]]] = None,
         data_format: Optional[Union[str, ChannelDimension]] = None,
-        input_data_format: Optional[Union[str, ChannelDimension]] = None,
     ) -> np.ndarray:
-        """Convert a mask to RGB format.
+        """Converts a segmentation map to RGB format.
 
         Args:
             image (`np.ndarray`):
-                Mask to convert to RGB format. If the mask is already in RGB format, it will be passed through.
+                Segmentation map with dimensions (height, width) where pixel values represent the class index.
             palette (`List[Tuple[int, int]]`, *optional*, defaults to `None`):
                 Palette to use to convert the mask to RGB format. If unset, the mask is duplicated across the channel
                 dimension.
@@ -203,21 +188,11 @@ class SegGptImageProcessor(BaseImageProcessor):
                 image is used. Can be one of:
                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
-            input_data_format (`ChannelDimension` or `str`, *optional*):
-                The channel dimension format for the input image. If unset, the channel dimension format is inferred
-                from the input image. Can be one of:
-                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
-                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
 
         Returns:
             `np.ndarray`: The mask in RGB format.
         """
-        return mask_to_rgb(
-            image,
-            palette=palette,
-            data_format=data_format,
-            input_data_format=input_data_format,
-        )
+        return mask_to_rgb(image, palette=palette, data_format=data_format)
 
     # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
     def resize(
@@ -271,7 +246,6 @@ class SegGptImageProcessor(BaseImageProcessor):
     def _preprocess_step(
         self,
         images: ImageInput,
-        is_mask: bool = False,
         do_resize: Optional[bool] = None,
         size: Dict[str, int] = None,
         resample: PILImageResampling = None,
@@ -282,6 +256,7 @@ class SegGptImageProcessor(BaseImageProcessor):
         image_std: Optional[Union[float, List[float]]] = None,
         data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        do_convert_rgb: Optional[bool] = None,
         num_labels: Optional[int] = None,
         **kwargs,
     ):
@@ -292,9 +267,6 @@ class SegGptImageProcessor(BaseImageProcessor):
             images (`ImageInput`):
                 Image to _preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                 passing in images with pixel values between 0 and 1, set `do_rescale=False`.
-            is_mask (`bool`, *optional*, defaults to `False`):
-                Whether the image is a mask. If True, the image is converted to RGB using the palette if
-                `self.num_labels` is specified otherwise RGB is achieved by duplicating the channel.
             do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                 Whether to resize the image.
             size (`Dict[str, int]`, *optional*, defaults to `self.size`):
@@ -331,6 +303,10 @@ class SegGptImageProcessor(BaseImageProcessor):
                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                 - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+                Whether to convert the prompt mask to RGB format. If `num_labels` is specified, a palette will be built
+                to map the prompt mask from a single channel to a 3 channel RGB. If unset, the prompt mask is duplicated
+                across the channel dimension. Must be set to `False` if the prompt mask is already in RGB format.
             num_labels: (`int`, *optional*):
                 Number of classes in the segmentation task (excluding the background). If specified, a palette will be
                 built, assuming that class_idx 0 is the background, to map the prompt mask from a single class_idx
@@ -340,6 +316,7 @@ class SegGptImageProcessor(BaseImageProcessor):
         do_resize = do_resize if do_resize is not None else self.do_resize
         do_rescale = do_rescale if do_rescale is not None else self.do_rescale
         do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
         resample = resample if resample is not None else self.resample
         rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
         image_mean = image_mean if image_mean is not None else self.image_mean
@@ -348,7 +325,8 @@ class SegGptImageProcessor(BaseImageProcessor):
         size = size if size is not None else self.size
         size_dict = get_size_dict(size)
 
-        images = make_list_of_images(images)
+        # If a segmentation map is passed, we expect 2D images
+        images = make_list_of_images(images, expected_ndims=2 if do_convert_rgb else 3)
 
         if not valid_images(images):
             raise ValueError(
@@ -374,11 +352,11 @@ class SegGptImageProcessor(BaseImageProcessor):
                 " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
             )
 
-        if input_data_format is None and not is_mask:
+        if input_data_format is None and not do_convert_rgb:
             # We assume that all images have the same channel dimension format.
             input_data_format = infer_channel_dimension_format(images[0])
 
-        if is_mask:
+        if do_convert_rgb:
             palette = self.get_palette(num_labels) if num_labels is not None else None
             # Since this is the input for the next transformations its format should be the same as the input_data_format
             images = [
@@ -423,6 +401,7 @@ class SegGptImageProcessor(BaseImageProcessor):
         do_normalize: Optional[bool] = None,
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
+        do_convert_rgb: Optional[bool] = None,
         num_labels: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
@@ -440,9 +419,12 @@ class SegGptImageProcessor(BaseImageProcessor):
                 Prompt image to _preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                 passing in images with pixel values between 0 and 1, set `do_rescale=False`.
             prompt_masks (`ImageInput`):
-                Prompt mask from prompt image to _preprocess. Expects a single or batch of masks. If the mask masks are
-                a single channel then it will be converted to RGB using the palette if `self.num_labels` is specified
-                or by just repeating the channel if not. If the mask is already in RGB format, it will be passed through.
+                Prompt mask from prompt image to _preprocess, which specifies the prompt_masks value in the preprocessed
+                output. Can either be in the format of segmentation maps (no channels) or RGB images. If in the format
+                of RGB images, `do_convert_rgb` should be set to `False`. If in the format of segmentation maps,
+                specifying `num_labels` is recommended to build a palette that maps the prompt mask from a single
+                channel to a 3 channel RGB. If `num_labels` is not specified, the prompt mask will be duplicated across
+                the channel dimension.
             do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                 Whether to resize the image.
             size (`Dict[str, int]`, *optional*, defaults to `self.size`):
@@ -461,6 +443,16 @@ class SegGptImageProcessor(BaseImageProcessor):
                 Image mean to use if `do_normalize` is set to `True`.
             image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                 Image standard deviation to use if `do_normalize` is set to `True`.
+            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+                Whether to convert the prompt mask to RGB format. If `num_labels` is specified, a palette will be built
+                to map the prompt mask from a single channel to a 3 channel RGB. If unset, the prompt mask is duplicated
+                across the channel dimension. Must be set to `False` if the prompt mask is already in RGB format.
+            num_labels (`int`, *optional*):
+                Number of classes in the segmentation task (excluding the background). If specified, a palette will be
+                built, assuming that class_idx 0 is the background, to map the prompt mask from a plain segmentation map
+                with no channels to a 3 channel RGB. Not specifying this will result in the prompt mask either being
+                passed through as is if it is already in RGB format (if `do_convert_rgb` is false) or being duplicated
+                across the channel dimension.
             return_tensors (`str` or `TensorType`, *optional*):
                 The type of tensors to return. Can be one of:
                 - Unset: Return a list of `np.ndarray`.
@@ -479,11 +471,6 @@ class SegGptImageProcessor(BaseImageProcessor):
                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                 - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
-            num_labels: (`int`, *optional*):
-                Number of classes in the segmentation task (excluding the background). If specified, a palette will be
-                built, assuming that class_idx 0 is the background, to map the prompt mask from a single class_idx
-                channel to a 3 channel RGB. Not specifying this will result in the prompt mask either being passed
-                through as is if it is already in RGB format or being duplicated across the channel dimension.
         """
         if all(v is None for v in [images, prompt_images, prompt_masks]):
             raise ValueError("At least one of images, prompt_images, prompt_masks must be specified.")
@@ -502,6 +489,7 @@ class SegGptImageProcessor(BaseImageProcessor):
                 do_normalize=do_normalize,
                 image_mean=image_mean,
                 image_std=image_std,
+                do_convert_rgb=False,
                 data_format=data_format,
                 input_data_format=input_data_format,
                 **kwargs,
@@ -521,6 +509,7 @@ class SegGptImageProcessor(BaseImageProcessor):
                 do_normalize=do_normalize,
                 image_mean=image_mean,
                 image_std=image_std,
+                do_convert_rgb=False,
                 data_format=data_format,
                 input_data_format=input_data_format,
                 **kwargs,
@@ -531,7 +520,6 @@ class SegGptImageProcessor(BaseImageProcessor):
         if prompt_masks is not None:
             prompt_masks = self._preprocess_step(
                 prompt_masks,
-                is_mask=True,
                 do_resize=do_resize,
                 size=size,
                 resample=PILImageResampling.NEAREST,
@@ -540,9 +528,10 @@ class SegGptImageProcessor(BaseImageProcessor):
                 do_normalize=do_normalize,
                 image_mean=image_mean,
                 image_std=image_std,
+                do_convert_rgb=do_convert_rgb,
+                num_labels=num_labels,
                 data_format=data_format,
                 input_data_format=input_data_format,
-                num_labels=num_labels,
                 **kwargs,
             )

...
@@ -30,6 +30,8 @@ if is_torch_available():
     from transformers.models.seggpt.modeling_seggpt import SegGptImageSegmentationOutput
 
 if is_vision_available():
+    from PIL import Image
+
     from transformers import SegGptImageProcessor
@@ -147,7 +149,7 @@ class SegGptImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
         mask_rgb = mask_binary.convert("RGB")
 
         inputs_binary = image_processor(images=None, prompt_masks=mask_binary, return_tensors="pt")
-        inputs_rgb = image_processor(images=None, prompt_masks=mask_rgb, return_tensors="pt")
+        inputs_rgb = image_processor(images=None, prompt_masks=mask_rgb, return_tensors="pt", do_convert_rgb=False)
 
         self.assertTrue((inputs_binary["prompt_masks"] == inputs_rgb["prompt_masks"]).all().item())
@@ -196,7 +198,11 @@ class SegGptImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
         image_processor = SegGptImageProcessor.from_pretrained("BAAI/seggpt-vit-large")
 
         inputs = image_processor(
-            images=input_image, prompt_images=prompt_image, prompt_masks=prompt_mask, return_tensors="pt"
+            images=input_image,
+            prompt_images=prompt_image,
+            prompt_masks=prompt_mask,
+            return_tensors="pt",
+            do_convert_rgb=False,
         )
 
         # Verify pixel values
@@ -229,3 +235,76 @@ class SegGptImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
             torch.allclose(inputs.prompt_pixel_values[0, :, :3, :3], expected_prompt_pixel_values, atol=1e-4)
         )
         self.assertTrue(torch.allclose(inputs.prompt_masks[0, :, :3, :3], expected_prompt_masks, atol=1e-4))
+
+    def test_prompt_mask_equivalence(self):
+        image_processor = self.image_processing_class(**self.image_processor_dict)
+        image_size = self.image_processor_tester.image_size
+
+        # Single Mask Examples
+        expected_single_shape = [1, 3, image_size, image_size]
+
+        # Single Semantic Map (2D)
+        image_np_2d = np.ones((image_size, image_size))
+        image_pt_2d = torch.ones((image_size, image_size))
+        image_pil_2d = Image.fromarray(image_np_2d)
+
+        inputs_np_2d = image_processor(images=None, prompt_masks=image_np_2d, return_tensors="pt")
+        inputs_pt_2d = image_processor(images=None, prompt_masks=image_pt_2d, return_tensors="pt")
+        inputs_pil_2d = image_processor(images=None, prompt_masks=image_pil_2d, return_tensors="pt")
+
+        self.assertTrue((inputs_np_2d["prompt_masks"] == inputs_pt_2d["prompt_masks"]).all().item())
+        self.assertTrue((inputs_np_2d["prompt_masks"] == inputs_pil_2d["prompt_masks"]).all().item())
+        self.assertEqual(list(inputs_np_2d["prompt_masks"].shape), expected_single_shape)
+
+        # Single RGB Images (3D)
+        image_np_3d = np.ones((3, image_size, image_size))
+        image_pt_3d = torch.ones((3, image_size, image_size))
+        image_pil_3d = Image.fromarray(image_np_3d.transpose(1, 2, 0).astype(np.uint8))
+
+        inputs_np_3d = image_processor(
+            images=None, prompt_masks=image_np_3d, return_tensors="pt", do_convert_rgb=False
+        )
+        inputs_pt_3d = image_processor(
+            images=None, prompt_masks=image_pt_3d, return_tensors="pt", do_convert_rgb=False
+        )
+        inputs_pil_3d = image_processor(
+            images=None, prompt_masks=image_pil_3d, return_tensors="pt", do_convert_rgb=False
+        )
+
+        self.assertTrue((inputs_np_3d["prompt_masks"] == inputs_pt_3d["prompt_masks"]).all().item())
+        self.assertTrue((inputs_np_3d["prompt_masks"] == inputs_pil_3d["prompt_masks"]).all().item())
+        self.assertEqual(list(inputs_np_3d["prompt_masks"].shape), expected_single_shape)
+
+        # Batched Examples
+        expected_batched_shape = [2, 3, image_size, image_size]
+
+        # Batched Semantic Maps (3D)
+        image_np_2d_batched = np.ones((2, image_size, image_size))
+        image_pt_2d_batched = torch.ones((2, image_size, image_size))
+
+        inputs_np_2d_batched = image_processor(images=None, prompt_masks=image_np_2d_batched, return_tensors="pt")
+        inputs_pt_2d_batched = image_processor(images=None, prompt_masks=image_pt_2d_batched, return_tensors="pt")
+
+        self.assertTrue((inputs_np_2d_batched["prompt_masks"] == inputs_pt_2d_batched["prompt_masks"]).all().item())
+        self.assertEqual(list(inputs_np_2d_batched["prompt_masks"].shape), expected_batched_shape)
+
+        # Batched RGB images (4D)
+        image_np_4d = np.ones((2, 3, image_size, image_size))
+        image_pt_4d = torch.ones((2, 3, image_size, image_size))
+
+        inputs_np_4d = image_processor(
+            images=None, prompt_masks=image_np_4d, return_tensors="pt", do_convert_rgb=False
+        )
+        inputs_pt_4d = image_processor(
+            images=None, prompt_masks=image_pt_4d, return_tensors="pt", do_convert_rgb=False
+        )
+
+        self.assertTrue((inputs_np_4d["prompt_masks"] == inputs_pt_4d["prompt_masks"]).all().item())
+        self.assertEqual(list(inputs_np_4d["prompt_masks"].shape), expected_batched_shape)
+
+        # Comparing Single and Batched Examples
+        self.assertTrue((inputs_np_2d["prompt_masks"][0] == inputs_np_3d["prompt_masks"][0]).all().item())
+        self.assertTrue((inputs_np_2d_batched["prompt_masks"][0] == inputs_np_2d["prompt_masks"][0]).all().item())
+        self.assertTrue((inputs_np_2d_batched["prompt_masks"][0] == inputs_np_3d["prompt_masks"][0]).all().item())
+        self.assertTrue((inputs_np_2d_batched["prompt_masks"][0] == inputs_np_4d["prompt_masks"][0]).all().item())
@@ -363,7 +363,11 @@ class SegGptModelIntegrationTest(unittest.TestCase):
         prompt_mask = masks[0]
 
         inputs = image_processor(
-            images=input_image, prompt_images=prompt_image, prompt_masks=prompt_mask, return_tensors="pt"
+            images=input_image,
+            prompt_images=prompt_image,
+            prompt_masks=prompt_mask,
+            return_tensors="pt",
+            do_convert_rgb=False,
         )
 
         inputs = inputs.to(torch_device)
@@ -404,7 +408,11 @@ class SegGptModelIntegrationTest(unittest.TestCase):
         prompt_masks = [masks[0], masks[2]]
 
         inputs = image_processor(
-            images=input_images, prompt_images=prompt_images, prompt_masks=prompt_masks, return_tensors="pt"
+            images=input_images,
+            prompt_images=prompt_images,
+            prompt_masks=prompt_masks,
+            return_tensors="pt",
+            do_convert_rgb=False,
         )
 
         inputs = {k: v.to(torch_device) for k, v in inputs.items()}
@@ -437,10 +445,16 @@ class SegGptModelIntegrationTest(unittest.TestCase):
         prompt_mask = masks[0]
 
         inputs = image_processor(
-            images=input_image, prompt_masks=prompt_mask, prompt_images=prompt_image, return_tensors="pt"
+            images=input_image,
+            prompt_masks=prompt_mask,
+            prompt_images=prompt_image,
+            return_tensors="pt",
+            do_convert_rgb=False,
         ).to(torch_device)
 
-        labels = image_processor(images=None, prompt_masks=label, return_tensors="pt")["prompt_masks"].to(torch_device)
+        labels = image_processor(images=None, prompt_masks=label, return_tensors="pt", do_convert_rgb=False)[
+            "prompt_masks"
+        ].to(torch_device)
 
         bool_masked_pos = prepare_bool_masked_pos(model.config).to(torch_device)

...