Unverified Commit cddad9ca authored by Philip Meier, committed by GitHub

support grayscale / RGB alpha conversions (#5567)

* support grayscale / RGB alpha conversions

* use _max_value from stable

* remove extra copy for PIL conversion

* simplify test image generation for color spaces with alpha channel

* use common _max_value in tests

* replace dynamically created dicts with if/else

* make color space conversion more explicit

* make even more explicit

* simplify alpha image generation

* fix if / elif

* add error for unknown conversions

* rename RGBA to RGB_ALPHA

* cleanup

* GRAYSCALE to GRAY
parent 24c0a147
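
For orientation, a minimal sketch of the tensor conversions this patch enables, based only on the kernels and enum members visible in the diff below (the import path of the prototype functional namespace is taken from the test file and may differ):

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

# Build a float RGB image and append a fully opaque alpha channel
# (1.0 is the max value for floating point images).
rgb = torch.rand(3, 16, 16)
rgba = torch.cat([rgb, torch.ones(1, 16, 16)], dim=-3)

# RGB_ALPHA -> GRAY: the constant alpha channel is stripped (an alpha channel
# holding anything other than the max value raises) and the color channels are collapsed.
gray = F.convert_image_color_space_tensor(
    rgba,
    old_color_space=features.ColorSpace.RGB_ALPHA,
    new_color_space=features.ColorSpace.GRAY,
)

# GRAY -> RGB_ALPHA: the single channel is repeated three times and a new
# constant alpha channel is appended.
back = F.convert_image_color_space_tensor(
    gray,
    old_color_space=features.ColorSpace.GRAY,
    new_color_space=features.ColorSpace.RGB_ALPHA,
)
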
@@ -102,7 +102,14 @@ class TestSmoke:
             (
                 transform,
                 itertools.chain.from_iterable(
-                    fn(dtypes=[torch.uint8], extra_dims=[(4,)])
+                    fn(
+                        color_spaces=[
+                            features.ColorSpace.GRAY,
+                            features.ColorSpace.RGB,
+                        ],
+                        dtypes=[torch.uint8],
+                        extra_dims=[(4,)],
+                    )
                     for fn in [
                         make_images,
                         make_vanilla_tensor_images,
@@ -152,3 +159,32 @@ class TestSmoke:
     )
     def test_random_resized_crop(self, transform, input):
         transform(input)
+
+    @parametrize(
+        [
+            (
+                transforms.ConvertImageColorSpace(color_space=new_color_space, old_color_space=old_color_space),
+                itertools.chain.from_iterable(
+                    [
+                        fn(color_spaces=[old_color_space])
+                        for fn in (
+                            make_images,
+                            make_vanilla_tensor_images,
+                            make_pil_images,
+                        )
+                    ]
+                ),
+            )
+            for old_color_space, new_color_space in itertools.product(
+                [
+                    features.ColorSpace.GRAY,
+                    features.ColorSpace.GRAY_ALPHA,
+                    features.ColorSpace.RGB,
+                    features.ColorSpace.RGB_ALPHA,
+                ],
+                repeat=2,
+            )
+        ]
+    )
+    def test_convert_image_color_space(self, transform, input):
+        transform(input)
@@ -7,33 +7,44 @@ import torchvision.prototype.transforms.functional as F
 from torch import jit
 from torch.nn.functional import one_hot
 from torchvision.prototype import features
+from torchvision.transforms.functional_tensor import _max_value as get_max_value

 make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")


-def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32):
+def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32, constant_alpha=True):
     size = size or torch.randint(16, 33, (2,)).tolist()
-    num_channels = {
-        features.ColorSpace.GRAYSCALE: 1,
-        features.ColorSpace.RGB: 3,
-    }[color_space]
+    try:
+        num_channels = {
+            features.ColorSpace.GRAY: 1,
+            features.ColorSpace.GRAY_ALPHA: 2,
+            features.ColorSpace.RGB: 3,
+            features.ColorSpace.RGB_ALPHA: 4,
+        }[color_space]
+    except KeyError as error:
+        raise pytest.UsageError() from error
     shape = (*extra_dims, num_channels, *size)
-    if dtype.is_floating_point:
-        data = torch.rand(shape, dtype=dtype)
-    else:
-        data = torch.randint(0, torch.iinfo(dtype).max, shape, dtype=dtype)
+    max_value = get_max_value(dtype)
+    data = make_tensor(shape, low=0, high=max_value, dtype=dtype)
+    if color_space in {features.ColorSpace.GRAY_ALPHA, features.ColorSpace.RGB_ALPHA} and constant_alpha:
+        data[..., -1, :, :] = max_value
     return features.Image(data, color_space=color_space)


-make_grayscale_image = functools.partial(make_image, color_space=features.ColorSpace.GRAYSCALE)
+make_grayscale_image = functools.partial(make_image, color_space=features.ColorSpace.GRAY)
 make_rgb_image = functools.partial(make_image, color_space=features.ColorSpace.RGB)


 def make_images(
     sizes=((16, 16), (7, 33), (31, 9)),
-    color_spaces=(features.ColorSpace.GRAYSCALE, features.ColorSpace.RGB),
+    color_spaces=(
+        features.ColorSpace.GRAY,
+        features.ColorSpace.GRAY_ALPHA,
+        features.ColorSpace.RGB,
+        features.ColorSpace.RGB_ALPHA,
+    ),
     dtypes=(torch.float32, torch.uint8),
     extra_dims=((4,), (2, 3)),
 ):
@@ -48,15 +59,12 @@ def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
     low, high = torch.broadcast_tensors(
         *[torch.as_tensor(arg) for arg in ((0, arg1) if arg2 is None else (arg1, arg2))]
     )
-    try:
-        return torch.stack(
-            [
-                torch.randint(low_scalar, high_scalar, (), **kwargs)
-                for low_scalar, high_scalar in zip(low.flatten().tolist(), high.flatten().tolist())
-            ]
-        ).reshape(low.shape)
-    except RuntimeError as error:
-        raise error
+    return torch.stack(
+        [
+            torch.randint(low_scalar, high_scalar, (), **kwargs)
+            for low_scalar, high_scalar in zip(low.flatten().tolist(), high.flatten().tolist())
+        ]
+    ).reshape(low.shape)


 def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch.int64):
@@ -83,8 +91,8 @@ def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch.int64):
         w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1)
         h = randint_with_tensor_bounds(1, torch.minimum(cy, width - cy) + 1)
         parts = (cx, cy, w, h)
-    else:  # format == features.BoundingBoxFormat._SENTINEL:
-        raise ValueError()
+    else:
+        raise pytest.UsageError()

     return features.BoundingBox(torch.stack(parts, dim=-1).to(dtype), format=format, image_size=image_size)
@@ -15,8 +15,23 @@ from ._feature import _Feature
 class ColorSpace(StrEnum):
     OTHER = StrEnum.auto()
-    GRAYSCALE = StrEnum.auto()
+    GRAY = StrEnum.auto()
+    GRAY_ALPHA = StrEnum.auto()
     RGB = StrEnum.auto()
+    RGB_ALPHA = StrEnum.auto()
+
+    @classmethod
+    def from_pil_mode(cls, mode: str) -> ColorSpace:
+        if mode == "L":
+            return cls.GRAY
+        elif mode == "LA":
+            return cls.GRAY_ALPHA
+        elif mode == "RGB":
+            return cls.RGB
+        elif mode == "RGBA":
+            return cls.RGB_ALPHA
+        else:
+            return cls.OTHER


 class Image(_Feature):
@@ -71,13 +86,17 @@ class Image(_Feature):
         if data.ndim < 2:
             return ColorSpace.OTHER
         elif data.ndim == 2:
-            return ColorSpace.GRAYSCALE
+            return ColorSpace.GRAY

         num_channels = data.shape[-3]
         if num_channels == 1:
-            return ColorSpace.GRAYSCALE
+            return ColorSpace.GRAY
+        elif num_channels == 2:
+            return ColorSpace.GRAY_ALPHA
         elif num_channels == 3:
             return ColorSpace.RGB
+        elif num_channels == 4:
+            return ColorSpace.RGB_ALPHA
         else:
             return ColorSpace.OTHER
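
A quick illustration of the PIL-mode mapping added above (a non-authoritative sketch that only exercises behavior visible in this hunk):

from torchvision.prototype.features import ColorSpace

# PIL mode strings map onto the extended enum; unknown modes fall back to OTHER.
assert ColorSpace.from_pil_mode("L") is ColorSpace.GRAY
assert ColorSpace.from_pil_mode("LA") is ColorSpace.GRAY_ALPHA
assert ColorSpace.from_pil_mode("RGBA") is ColorSpace.RGB_ALPHA
assert ColorSpace.from_pil_mode("CMYK") is ColorSpace.OTHER
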
@@ -48,11 +48,11 @@ class ConvertImageColorSpace(Transform):
         super().__init__()

         if isinstance(color_space, str):
-            color_space = features.ColorSpace[color_space]
+            color_space = features.ColorSpace.from_str(color_space)
         self.color_space = color_space

         if isinstance(old_color_space, str):
-            old_color_space = features.ColorSpace[old_color_space]
+            old_color_space = features.ColorSpace.from_str(old_color_space)
         self.old_color_space = old_color_space

     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
@@ -72,13 +72,6 @@ class ConvertImageColorSpace(Transform):
                 input, old_color_space=self.old_color_space, new_color_space=self.color_space
             )
         elif isinstance(input, PIL.Image.Image):
-            old_color_space = {
-                "L": features.ColorSpace.GRAYSCALE,
-                "RGB": features.ColorSpace.RGB,
-            }.get(input.mode, features.ColorSpace.OTHER)
-
-            return F.convert_image_color_space_pil(
-                input, old_color_space=old_color_space, new_color_space=self.color_space
-            )
+            return F.convert_image_color_space_pil(input, color_space=self.color_space)
         else:
             return input
from typing import Tuple, Optional
import PIL.Image
import torch
from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
get_dimensions_image_tensor = _FT.get_dimensions
get_dimensions_image_pil = _FP.get_dimensions
@@ -57,41 +58,88 @@ def convert_bounding_box_format(
     return bounding_box


-def _grayscale_to_rgb_tensor(grayscale: torch.Tensor) -> torch.Tensor:
-    repeats = [1] * grayscale.ndim
-    repeats[-3] = 3
-    return grayscale.repeat(repeats)
+def _split_alpha(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    return image[..., :-1, :, :], image[..., -1:, :, :]


-def convert_image_color_space_tensor(
-    image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
-) -> torch.Tensor:
-    if new_color_space == old_color_space:
-        return image.clone()
+def _strip_alpha(image: torch.Tensor) -> torch.Tensor:
+    image, alpha = _split_alpha(image)
+    if not torch.all(alpha == _FT._max_value(alpha.dtype)):
+        raise RuntimeError(
+            "Stripping the alpha channel if it contains values other than the max value is not supported."
+        )
+    return image

-    if old_color_space == ColorSpace.GRAYSCALE:
-        image = _grayscale_to_rgb_tensor(image)

-    if new_color_space == ColorSpace.GRAYSCALE:
-        image = _FT.rgb_to_grayscale(image)
+def _add_alpha(image: torch.Tensor, alpha: Optional[torch.Tensor] = None) -> torch.Tensor:
+    if alpha is None:
+        shape = list(image.shape)
+        shape[-3] = 1
+        alpha = torch.full(shape, _FT._max_value(image.dtype), dtype=image.dtype, device=image.device)
+    return torch.cat((image, alpha), dim=-3)

-    return image

-
-def _grayscale_to_rgb_pil(grayscale: PIL.Image.Image) -> PIL.Image.Image:
-    return grayscale.convert("RGB")
+def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
+    repeats = [1] * grayscale.ndim
+    repeats[-3] = 3
+    return grayscale.repeat(repeats)


-def convert_image_color_space_pil(
-    image: PIL.Image.Image, old_color_space: ColorSpace, new_color_space: ColorSpace
-) -> PIL.Image.Image:
-    if new_color_space == old_color_space:
-        return image.copy()
+_rgb_to_gray = _FT.rgb_to_grayscale

-    if old_color_space == ColorSpace.GRAYSCALE:
-        image = _grayscale_to_rgb_pil(image)

-    if new_color_space == ColorSpace.GRAYSCALE:
-        image = _FP.to_grayscale(image)
+def convert_image_color_space_tensor(
+    image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
+) -> torch.Tensor:
+    if new_color_space == old_color_space:
+        return image.clone()

-    return image
+    if old_color_space == ColorSpace.OTHER or new_color_space == ColorSpace.OTHER:
+        raise RuntimeError(f"Conversion to or from {ColorSpace.OTHER} is not supported.")
+
+    if old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.GRAY_ALPHA:
+        return _add_alpha(image)
+    elif old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.RGB:
+        return _gray_to_rgb(image)
+    elif old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.RGB_ALPHA:
+        return _add_alpha(_gray_to_rgb(image))
+    elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.GRAY:
+        return _strip_alpha(image)
+    elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB:
+        return _gray_to_rgb(_strip_alpha(image))
+    elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB_ALPHA:
+        image, alpha = _split_alpha(image)
+        return _add_alpha(_gray_to_rgb(image), alpha)
+    elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY:
+        return _rgb_to_gray(image)
+    elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY_ALPHA:
+        return _add_alpha(_rgb_to_gray(image))
+    elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.RGB_ALPHA:
+        return _add_alpha(image)
+    elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY:
+        return _rgb_to_gray(_strip_alpha(image))
+    elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY_ALPHA:
+        image, alpha = _split_alpha(image)
+        return _add_alpha(_rgb_to_gray(image), alpha)
+    elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.RGB:
+        return _strip_alpha(image)
+    else:
+        raise RuntimeError(f"Conversion from {old_color_space} to {new_color_space} is not supported.")
+
+
+_COLOR_SPACE_TO_PIL_MODE = {
+    ColorSpace.GRAY: "L",
+    ColorSpace.GRAY_ALPHA: "LA",
+    ColorSpace.RGB: "RGB",
+    ColorSpace.RGB_ALPHA: "RGBA",
+}
+
+
+def convert_image_color_space_pil(image: PIL.Image.Image, color_space: ColorSpace) -> PIL.Image.Image:
+    old_mode = image.mode
+    try:
+        new_mode = _COLOR_SPACE_TO_PIL_MODE[color_space]
+    except KeyError:
+        raise ValueError(f"Conversion from {ColorSpace.from_pil_mode(old_mode)} to {color_space} is not supported.")
+
+    return image.convert(new_mode)
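
To round this off, a small usage sketch of the new PIL kernel and the transform wrapper, under the assumption that the prototype transforms namespace exposes ConvertImageColorSpace and convert_image_color_space_pil as used in the tests above (PIL.Image.new is standard Pillow):

import PIL.Image
from torchvision.prototype import features, transforms
from torchvision.prototype.transforms.functional import convert_image_color_space_pil

pil_rgb = PIL.Image.new("RGB", (16, 16), color=(10, 20, 30))

# The PIL kernel only needs the target color space; the source is read from image.mode.
pil_la = convert_image_color_space_pil(pil_rgb, color_space=features.ColorSpace.GRAY_ALPHA)
assert pil_la.mode == "LA"

# The transform dispatches to the tensor or the PIL kernel based on the input type;
# old_color_space is only consulted on the tensor path.
transform = transforms.ConvertImageColorSpace(
    color_space=features.ColorSpace.RGB_ALPHA,
    old_color_space=features.ColorSpace.RGB,
)
pil_rgba = transform(pil_rgb)
assert pil_rgba.mode == "RGBA"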