Unverified commit 4c7e8d09, authored by amyeroberts, committed by GitHub

Add object detection + segmentation transforms (#20003)



* Add transforms for object detection

* Update src/transformers/image_transforms.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Better var names & docstring

* Remove unused var desc in docstring

* Update src/transformers/image_transforms.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 163ac3d3
@@ -21,8 +21,16 @@ Most of those are only useful if you are studying the code of the image processors
 [[autodoc]] image_transforms.center_crop
+[[autodoc]] image_transforms.center_to_corners_format
+[[autodoc]] image_transforms.corners_to_center_format
+[[autodoc]] image_transforms.id_to_rgb
 [[autodoc]] image_transforms.normalize
+[[autodoc]] image_transforms.rgb_to_id
 [[autodoc]] image_transforms.rescale
 [[autodoc]] image_transforms.resize
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
 import numpy as np
-from transformers.image_utils import PILImageResampling
+from transformers.utils import TensorType
 from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available
@@ -27,6 +27,7 @@ if is_vision_available():
     from .image_utils import (
         ChannelDimension,
+        PILImageResampling,
         get_channel_dimension_axis,
         get_image_size,
         infer_channel_dimension_format,
@@ -108,7 +109,7 @@ def rescale(
 def to_pil_image(
-    image: Union[np.ndarray, PIL.Image.Image, "torch.Tensor", "tf.Tensor", "jnp.Tensor"],
+    image: Union[np.ndarray, PIL.Image.Image, "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
     do_rescale: Optional[bool] = None,
 ) -> PIL.Image.Image:
     """
@@ -300,6 +301,9 @@ def normalize(
         image = to_numpy_array(image)
         image = rescale(image, scale=1 / 255)

+    if not isinstance(image, np.ndarray):
+        raise ValueError("image must be a numpy array")
+
     input_data_format = infer_channel_dimension_format(image)
     channel_axis = get_channel_dimension_axis(image)
     num_channels = image.shape[channel_axis]
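For reference, a minimal sketch (not part of the diff) of what the new guard changes for callers, using only the public `normalize(image, mean, std)` signature exercised by the tests added further down:

```python
import numpy as np
from transformers.image_transforms import normalize

# Non-array input now fails fast with a clear message
try:
    normalize(5, 5, 5)
except ValueError as err:
    print(err)  # image must be a numpy array

# A well-formed call: per-channel mean/std on a float HWC image
image = np.random.randint(0, 256, (224, 224, 3)) / 255
normalized = normalize(image, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
```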
@@ -420,3 +424,147 @@ def center_crop(
         new_image = to_pil_image(new_image)

     return new_image
def _center_to_corners_format_torch(bboxes_center: "torch.Tensor") -> "torch.Tensor":
    center_x, center_y, width, height = bboxes_center.unbind(-1)
    bbox_corners = torch.stack(
        # top left x, top left y, bottom right x, bottom right y
        [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)],
        dim=-1,
    )
    return bbox_corners


def _center_to_corners_format_numpy(bboxes_center: np.ndarray) -> np.ndarray:
    center_x, center_y, width, height = bboxes_center.T
    bboxes_corners = np.stack(
        # top left x, top left y, bottom right x, bottom right y
        [center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height],
        axis=-1,
    )
    return bboxes_corners


def _center_to_corners_format_tf(bboxes_center: "tf.Tensor") -> "tf.Tensor":
    center_x, center_y, width, height = tf.unstack(bboxes_center, axis=-1)
    bboxes_corners = tf.stack(
        # top left x, top left y, bottom right x, bottom right y
        [center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height],
        axis=-1,
    )
    return bboxes_corners
# 2 functions below inspired by https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
def center_to_corners_format(bboxes_center: TensorType) -> TensorType:
    """
    Converts bounding boxes from center format to corners format.

    center format: contains the coordinates of the center of the box and its width and height
        (center_x, center_y, width, height)
    corners format: contains the coordinates of the top-left and bottom-right corners of the box
        (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
    """
    # This function is used during the model forward pass, so we stay in the input framework
    # where possible instead of converting to numpy
    if is_torch_tensor(bboxes_center):
        return _center_to_corners_format_torch(bboxes_center)
    elif isinstance(bboxes_center, np.ndarray):
        return _center_to_corners_format_numpy(bboxes_center)
    elif is_tf_tensor(bboxes_center):
        return _center_to_corners_format_tf(bboxes_center)

    raise ValueError(f"Unsupported input type {type(bboxes_center)}")
def _corners_to_center_format_torch(bboxes_corners: "torch.Tensor") -> "torch.Tensor":
    top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.unbind(-1)
    b = [
        (top_left_x + bottom_right_x) / 2,  # center x
        (top_left_y + bottom_right_y) / 2,  # center y
        (bottom_right_x - top_left_x),  # width
        (bottom_right_y - top_left_y),  # height
    ]
    return torch.stack(b, dim=-1)


def _corners_to_center_format_numpy(bboxes_corners: np.ndarray) -> np.ndarray:
    top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.T
    bboxes_center = np.stack(
        [
            (top_left_x + bottom_right_x) / 2,  # center x
            (top_left_y + bottom_right_y) / 2,  # center y
            (bottom_right_x - top_left_x),  # width
            (bottom_right_y - top_left_y),  # height
        ],
        axis=-1,
    )
    return bboxes_center


def _corners_to_center_format_tf(bboxes_corners: "tf.Tensor") -> "tf.Tensor":
    top_left_x, top_left_y, bottom_right_x, bottom_right_y = tf.unstack(bboxes_corners, axis=-1)
    bboxes_center = tf.stack(
        [
            (top_left_x + bottom_right_x) / 2,  # center x
            (top_left_y + bottom_right_y) / 2,  # center y
            (bottom_right_x - top_left_x),  # width
            (bottom_right_y - top_left_y),  # height
        ],
        axis=-1,
    )
    return bboxes_center
def corners_to_center_format(bboxes_corners: TensorType) -> TensorType:
    """
    Converts bounding boxes from corners format to center format.

    corners format: contains the coordinates of the top-left and bottom-right corners of the box
        (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
    center format: contains the coordinates of the center of the box and its width and height
        (center_x, center_y, width, height)
    """
    # The inverse function also accepts several input types, so it is implemented here too
    if is_torch_tensor(bboxes_corners):
        return _corners_to_center_format_torch(bboxes_corners)
    elif isinstance(bboxes_corners, np.ndarray):
        return _corners_to_center_format_numpy(bboxes_corners)
    elif is_tf_tensor(bboxes_corners):
        return _corners_to_center_format_tf(bboxes_corners)

    raise ValueError(f"Unsupported input type {type(bboxes_corners)}")
# 2 functions below copied from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py
# Copyright (c) 2018, Alexander Kirillov
# All rights reserved.
def rgb_to_id(color):
    """
    Converts RGB color to unique ID.
    """
    if isinstance(color, np.ndarray) and len(color.shape) == 3:
        if color.dtype == np.uint8:
            color = color.astype(np.int32)
        return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
    return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
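The encoding is base-256 positional, ID = R + 256 * G + 256^2 * B. Worked through for the color used in the new tests (sketch, not part of the diff):

```python
from transformers.image_transforms import rgb_to_id

# 125 + 256 * 4 + 256**2 * 255 = 125 + 1024 + 16711680 = 16712829
assert rgb_to_id([125, 4, 255]) == 16712829
```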
def id_to_rgb(id_map):
    """
    Converts unique ID to RGB color.
    """
    if isinstance(id_map, np.ndarray):
        id_map_copy = id_map.copy()
        rgb_shape = tuple(list(id_map.shape) + [3])
        rgb_map = np.zeros(rgb_shape, dtype=np.uint8)
        for i in range(3):
            rgb_map[..., i] = id_map_copy % 256
            id_map_copy //= 256
        return rgb_map
    color = []
    for _ in range(3):
        color.append(id_map % 256)
        id_map //= 256
    return color
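Decoding reverses this with successive mod / floor-division by 256, least-significant channel first (sketch, not part of the diff):

```python
from transformers.image_transforms import id_to_rgb

# 16712829 % 256 = 125; (16712829 // 256) % 256 = 4; (16712829 // 256**2) % 256 = 255
assert id_to_rgb(16712829) == [125, 4, 255]
```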
@@ -36,9 +36,13 @@ if is_vision_available():
     from transformers.image_transforms import (
         center_crop,
+        center_to_corners_format,
+        corners_to_center_format,
         get_resize_output_image_size,
+        id_to_rgb,
         normalize,
         resize,
+        rgb_to_id,
         to_channel_dimension_format,
         to_pil_image,
     )
@@ -178,6 +182,11 @@ class ImageTransformsTester(unittest.TestCase):
     def test_normalize(self):
         image = np.random.randint(0, 256, (224, 224, 3)) / 255

+        # Test that exception is raised if inputs are incorrect
+        # Not a numpy array image
+        with self.assertRaises(ValueError):
+            normalize(5, 5, 5)
+
         # Number of mean values != number of channels
         with self.assertRaises(ValueError):
             normalize(image, mean=(0.5, 0.6), std=1)
@@ -219,3 +228,64 @@ class ImageTransformsTester(unittest.TestCase):
         self.assertIsInstance(cropped_image, np.ndarray)
         self.assertEqual(cropped_image.shape, (300, 260, 3))
         self.assertTrue(np.allclose(cropped_image, expected_image))
    def test_center_to_corners_format(self):
        bbox_center = np.array([[10, 20, 4, 8], [15, 16, 3, 4]])
        expected = np.array([[8, 16, 12, 24], [13.5, 14, 16.5, 18]])
        self.assertTrue(np.allclose(center_to_corners_format(bbox_center), expected))

        # Check that the function and inverse function are inverse of each other
        self.assertTrue(np.allclose(corners_to_center_format(center_to_corners_format(bbox_center)), bbox_center))

    def test_corners_to_center_format(self):
        bbox_corners = np.array([[8, 16, 12, 24], [13.5, 14, 16.5, 18]])
        expected = np.array([[10, 20, 4, 8], [15, 16, 3, 4]])
        self.assertTrue(np.allclose(corners_to_center_format(bbox_corners), expected))

        # Check that the function and inverse function are inverse of each other
        self.assertTrue(np.allclose(center_to_corners_format(corners_to_center_format(bbox_corners)), bbox_corners))

    def test_rgb_to_id(self):
        # test list input
        rgb = [125, 4, 255]
        self.assertEqual(rgb_to_id(rgb), 16712829)

        # test numpy array input
        color = np.array(
            [
                [
                    [213, 54, 165],
                    [88, 207, 39],
                    [156, 108, 128],
                ],
                [
                    [183, 194, 46],
                    [137, 58, 88],
                    [114, 131, 233],
                ],
            ]
        )
        expected = np.array([[10827477, 2608984, 8416412], [3064503, 5782153, 15303538]])
        self.assertTrue(np.allclose(rgb_to_id(color), expected))

    def test_id_to_rgb(self):
        # test int input
        self.assertEqual(id_to_rgb(16712829), [125, 4, 255])

        # test array input
        id_array = np.array([[10827477, 2608984, 8416412], [3064503, 5782153, 15303538]])
        color = np.array(
            [
                [
                    [213, 54, 165],
                    [88, 207, 39],
                    [156, 108, 128],
                ],
                [
                    [183, 194, 46],
                    [137, 58, 88],
                    [114, 131, 233],
                ],
            ]
        )
        self.assertTrue(np.allclose(id_to_rgb(id_array), color))