"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "a3a078eefb703cc22412af0d133e9c378cb3f56f"
Unverified Commit cddad9ca authored by Philip Meier, committed by GitHub

support grayscale / RGB alpha conversions (#5567)

* support grayscale / RGB alpha conversions

* use _max_value from stable

* remove extra copy for PIL conversion

* simplify test image generation for color spaces with alpha channel

* use common _max_value in tests

* replace dynamically created dicts with if/else

* make color space conversion more explicit

* make even more explicit

* simplify alpha image generation

* fix if / elif

* add error for unknown conversions

* rename RGBA to RGB_ALPHA

* cleanup

* GRAYSCALE to GRAY
parent 24c0a147
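For orientation, a minimal sketch of what the added conversions allow from user code. This is a sketch only: it assumes the prototype import paths shown in this diff and that old_color_space can be omitted for PIL input, whose mode already carries the color space.

    import PIL.Image
    from torchvision.prototype import features, transforms

    # Hypothetical usage: drop the alpha channel of an RGBA PIL image.
    img = PIL.Image.new("RGBA", (16, 16), (128, 64, 32, 255))
    convert = transforms.ConvertImageColorSpace(color_space=features.ColorSpace.RGB)
    rgb = convert(img)  # returned image is in PIL mode "RGB"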
@@ -102,7 +102,14 @@ class TestSmoke:
         (
             transform,
             itertools.chain.from_iterable(
-                fn(dtypes=[torch.uint8], extra_dims=[(4,)])
+                fn(
+                    color_spaces=[
+                        features.ColorSpace.GRAY,
+                        features.ColorSpace.RGB,
+                    ],
+                    dtypes=[torch.uint8],
+                    extra_dims=[(4,)],
+                )
                 for fn in [
                     make_images,
                     make_vanilla_tensor_images,
@@ -152,3 +159,32 @@ class TestSmoke:
     )
     def test_random_resized_crop(self, transform, input):
         transform(input)
+
+    @parametrize(
+        [
+            (
+                transforms.ConvertImageColorSpace(color_space=new_color_space, old_color_space=old_color_space),
+                itertools.chain.from_iterable(
+                    [
+                        fn(color_spaces=[old_color_space])
+                        for fn in (
+                            make_images,
+                            make_vanilla_tensor_images,
+                            make_pil_images,
+                        )
+                    ]
+                ),
+            )
+            for old_color_space, new_color_space in itertools.product(
+                [
+                    features.ColorSpace.GRAY,
+                    features.ColorSpace.GRAY_ALPHA,
+                    features.ColorSpace.RGB,
+                    features.ColorSpace.RGB_ALPHA,
+                ],
+                repeat=2,
+            )
+        ]
+    )
+    def test_convert_image_color_space(self, transform, input):
+        transform(input)
@@ -7,33 +7,44 @@ import torchvision.prototype.transforms.functional as F
 from torch import jit
 from torch.nn.functional import one_hot
 from torchvision.prototype import features
+from torchvision.transforms.functional_tensor import _max_value as get_max_value

 make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")


-def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32):
+def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32, constant_alpha=True):
     size = size or torch.randint(16, 33, (2,)).tolist()
-    num_channels = {
-        features.ColorSpace.GRAYSCALE: 1,
-        features.ColorSpace.RGB: 3,
-    }[color_space]
+    try:
+        num_channels = {
+            features.ColorSpace.GRAY: 1,
+            features.ColorSpace.GRAY_ALPHA: 2,
+            features.ColorSpace.RGB: 3,
+            features.ColorSpace.RGB_ALPHA: 4,
+        }[color_space]
+    except KeyError as error:
+        raise pytest.UsageError() from error
     shape = (*extra_dims, num_channels, *size)
-    if dtype.is_floating_point:
-        data = torch.rand(shape, dtype=dtype)
-    else:
-        data = torch.randint(0, torch.iinfo(dtype).max, shape, dtype=dtype)
+    max_value = get_max_value(dtype)
+    data = make_tensor(shape, low=0, high=max_value, dtype=dtype)
+    if color_space in {features.ColorSpace.GRAY_ALPHA, features.ColorSpace.RGB_ALPHA} and constant_alpha:
+        data[..., -1, :, :] = max_value
     return features.Image(data, color_space=color_space)


-make_grayscale_image = functools.partial(make_image, color_space=features.ColorSpace.GRAYSCALE)
+make_grayscale_image = functools.partial(make_image, color_space=features.ColorSpace.GRAY)
 make_rgb_image = functools.partial(make_image, color_space=features.ColorSpace.RGB)


 def make_images(
     sizes=((16, 16), (7, 33), (31, 9)),
-    color_spaces=(features.ColorSpace.GRAYSCALE, features.ColorSpace.RGB),
+    color_spaces=(
+        features.ColorSpace.GRAY,
+        features.ColorSpace.GRAY_ALPHA,
+        features.ColorSpace.RGB,
+        features.ColorSpace.RGB_ALPHA,
+    ),
     dtypes=(torch.float32, torch.uint8),
     extra_dims=((4,), (2, 3)),
 ):
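In short, the reworked make_image draws data uniformly up to the dtype's maximum via the shared _max_value helper and, for the alpha color spaces, pins the alpha channel to that maximum unless constant_alpha=False. A hypothetical call against this test-internal helper:

    # Hypothetical usage of the updated test helper (not public API):
    image = make_image((16, 16), color_space=features.ColorSpace.RGB_ALPHA, dtype=torch.uint8)
    # image.shape == (4, 16, 16); the alpha channel is all 255 because constant_alpha=True

    noisy = make_image((16, 16), color_space=features.ColorSpace.GRAY_ALPHA, constant_alpha=False)
    # alpha keeps random values; stripping it (e.g. GRAY_ALPHA -> GRAY) would then typically raise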
@@ -48,15 +59,12 @@ def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
     low, high = torch.broadcast_tensors(
         *[torch.as_tensor(arg) for arg in ((0, arg1) if arg2 is None else (arg1, arg2))]
     )
-    try:
-        return torch.stack(
-            [
-                torch.randint(low_scalar, high_scalar, (), **kwargs)
-                for low_scalar, high_scalar in zip(low.flatten().tolist(), high.flatten().tolist())
-            ]
-        ).reshape(low.shape)
-    except RuntimeError as error:
-        raise error
+    return torch.stack(
+        [
+            torch.randint(low_scalar, high_scalar, (), **kwargs)
+            for low_scalar, high_scalar in zip(low.flatten().tolist(), high.flatten().tolist())
+        ]
+    ).reshape(low.shape)


 def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch.int64):
@@ -83,8 +91,8 @@ def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch
         w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1)
         h = randint_with_tensor_bounds(1, torch.minimum(cy, width - cy) + 1)
         parts = (cx, cy, w, h)
-    else:  # format == features.BoundingBoxFormat._SENTINEL:
-        raise ValueError()
+    else:
+        raise pytest.UsageError()

     return features.BoundingBox(torch.stack(parts, dim=-1).to(dtype), format=format, image_size=image_size)
@@ -15,8 +15,23 @@ from ._feature import _Feature

 class ColorSpace(StrEnum):
     OTHER = StrEnum.auto()
-    GRAYSCALE = StrEnum.auto()
+    GRAY = StrEnum.auto()
+    GRAY_ALPHA = StrEnum.auto()
     RGB = StrEnum.auto()
+    RGB_ALPHA = StrEnum.auto()
+
+    @classmethod
+    def from_pil_mode(cls, mode: str) -> ColorSpace:
+        if mode == "L":
+            return cls.GRAY
+        elif mode == "LA":
+            return cls.GRAY_ALPHA
+        elif mode == "RGB":
+            return cls.RGB
+        elif mode == "RGBA":
+            return cls.RGB_ALPHA
+        else:
+            return cls.OTHER


 class Image(_Feature):
@@ -71,13 +86,17 @@ class Image(_Feature):
         if data.ndim < 2:
             return ColorSpace.OTHER
         elif data.ndim == 2:
-            return ColorSpace.GRAYSCALE
+            return ColorSpace.GRAY

         num_channels = data.shape[-3]
         if num_channels == 1:
-            return ColorSpace.GRAYSCALE
+            return ColorSpace.GRAY
+        elif num_channels == 2:
+            return ColorSpace.GRAY_ALPHA
         elif num_channels == 3:
             return ColorSpace.RGB
+        elif num_channels == 4:
+            return ColorSpace.RGB_ALPHA
         else:
             return ColorSpace.OTHER
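As a quick illustration of the new enum surface (a sketch assuming only what the hunks above add: the GRAY_ALPHA / RGB_ALPHA members and the from_pil_mode helper):

    from torchvision.prototype.features import ColorSpace

    ColorSpace.from_pil_mode("L")     # ColorSpace.GRAY
    ColorSpace.from_pil_mode("LA")    # ColorSpace.GRAY_ALPHA
    ColorSpace.from_pil_mode("RGBA")  # ColorSpace.RGB_ALPHA
    ColorSpace.from_pil_mode("CMYK")  # ColorSpace.OTHER (any unmapped mode)

Likewise, the channel-count guess above now maps 2 channels to GRAY_ALPHA and 4 channels to RGB_ALPHA instead of falling back to OTHER.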
@@ -48,11 +48,11 @@ class ConvertImageColorSpace(Transform):
         super().__init__()

         if isinstance(color_space, str):
-            color_space = features.ColorSpace[color_space]
+            color_space = features.ColorSpace.from_str(color_space)
         self.color_space = color_space

         if isinstance(old_color_space, str):
-            old_color_space = features.ColorSpace[old_color_space]
+            old_color_space = features.ColorSpace.from_str(old_color_space)
         self.old_color_space = old_color_space

     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
@@ -72,13 +72,6 @@ class ConvertImageColorSpace(Transform):
                 input, old_color_space=self.old_color_space, new_color_space=self.color_space
             )
         elif isinstance(input, PIL.Image.Image):
-            old_color_space = {
-                "L": features.ColorSpace.GRAYSCALE,
-                "RGB": features.ColorSpace.RGB,
-            }.get(input.mode, features.ColorSpace.OTHER)
-            return F.convert_image_color_space_pil(
-                input, old_color_space=old_color_space, new_color_space=self.color_space
-            )
+            return F.convert_image_color_space_pil(input, color_space=self.color_space)
         else:
             return input
+from typing import Tuple, Optional
+
 import PIL.Image
 import torch
 from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
 from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP

 get_dimensions_image_tensor = _FT.get_dimensions
 get_dimensions_image_pil = _FP.get_dimensions
@@ -57,41 +58,88 @@ def convert_bounding_box_format(
     return bounding_box


-def _grayscale_to_rgb_tensor(grayscale: torch.Tensor) -> torch.Tensor:
-    repeats = [1] * grayscale.ndim
-    repeats[-3] = 3
-    return grayscale.repeat(repeats)
-
-
-def convert_image_color_space_tensor(
-    image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
-) -> torch.Tensor:
-    if new_color_space == old_color_space:
-        return image.clone()
-
-    if old_color_space == ColorSpace.GRAYSCALE:
-        image = _grayscale_to_rgb_tensor(image)
-
-    if new_color_space == ColorSpace.GRAYSCALE:
-        image = _FT.rgb_to_grayscale(image)
-
-    return image
-
-
-def _grayscale_to_rgb_pil(grayscale: PIL.Image.Image) -> PIL.Image.Image:
-    return grayscale.convert("RGB")
-
-
-def convert_image_color_space_pil(
-    image: PIL.Image.Image, old_color_space: ColorSpace, new_color_space: ColorSpace
-) -> PIL.Image.Image:
-    if new_color_space == old_color_space:
-        return image.copy()
-
-    if old_color_space == ColorSpace.GRAYSCALE:
-        image = _grayscale_to_rgb_pil(image)
-
-    if new_color_space == ColorSpace.GRAYSCALE:
-        image = _FP.to_grayscale(image)
-
-    return image
+def _split_alpha(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    return image[..., :-1, :, :], image[..., -1:, :, :]
+
+
+def _strip_alpha(image: torch.Tensor) -> torch.Tensor:
+    image, alpha = _split_alpha(image)
+    if not torch.all(alpha == _FT._max_value(alpha.dtype)):
+        raise RuntimeError(
+            "Stripping the alpha channel if it contains values other than the max value is not supported."
+        )
+    return image
+
+
+def _add_alpha(image: torch.Tensor, alpha: Optional[torch.Tensor] = None) -> torch.Tensor:
+    if alpha is None:
+        shape = list(image.shape)
+        shape[-3] = 1
+        alpha = torch.full(shape, _FT._max_value(image.dtype), dtype=image.dtype, device=image.device)
+    return torch.cat((image, alpha), dim=-3)
+
+
+def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
+    repeats = [1] * grayscale.ndim
+    repeats[-3] = 3
+    return grayscale.repeat(repeats)
+
+
+_rgb_to_gray = _FT.rgb_to_grayscale
+
+
+def convert_image_color_space_tensor(
+    image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
+) -> torch.Tensor:
+    if new_color_space == old_color_space:
+        return image.clone()
+
+    if old_color_space == ColorSpace.OTHER or new_color_space == ColorSpace.OTHER:
+        raise RuntimeError(f"Conversion to or from {ColorSpace.OTHER} is not supported.")
+
+    if old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.GRAY_ALPHA:
+        return _add_alpha(image)
+    elif old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.RGB:
+        return _gray_to_rgb(image)
+    elif old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.RGB_ALPHA:
+        return _add_alpha(_gray_to_rgb(image))
+    elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.GRAY:
+        return _strip_alpha(image)
+    elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB:
+        return _gray_to_rgb(_strip_alpha(image))
+    elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB_ALPHA:
+        image, alpha = _split_alpha(image)
+        return _add_alpha(_gray_to_rgb(image), alpha)
+    elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY:
+        return _rgb_to_gray(image)
+    elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY_ALPHA:
+        return _add_alpha(_rgb_to_gray(image))
+    elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.RGB_ALPHA:
+        return _add_alpha(image)
+    elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY:
+        return _rgb_to_gray(_strip_alpha(image))
+    elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY_ALPHA:
+        image, alpha = _split_alpha(image)
+        return _add_alpha(_rgb_to_gray(image), alpha)
+    elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.RGB:
+        return _strip_alpha(image)
+    else:
+        raise RuntimeError(f"Conversion from {old_color_space} to {new_color_space} is not supported.")
+
+
+_COLOR_SPACE_TO_PIL_MODE = {
+    ColorSpace.GRAY: "L",
+    ColorSpace.GRAY_ALPHA: "LA",
+    ColorSpace.RGB: "RGB",
+    ColorSpace.RGB_ALPHA: "RGBA",
+}
+
+
+def convert_image_color_space_pil(image: PIL.Image.Image, color_space: ColorSpace) -> PIL.Image.Image:
+    old_mode = image.mode
+    try:
+        new_mode = _COLOR_SPACE_TO_PIL_MODE[color_space]
+    except KeyError:
+        raise ValueError(f"Conversion from {ColorSpace.from_pil_mode(old_mode)} to {color_space} is not supported.")
+
+    return image.convert(new_mode)
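To make the routing above concrete, a small sketch of the tensor path (import paths assume the prototype layout used elsewhere in this diff):

    import torch
    from torchvision.prototype.features import ColorSpace
    from torchvision.prototype.transforms.functional import convert_image_color_space_tensor

    gray = torch.rand(1, 16, 16)  # single-channel float image, values in [0, 1)

    # GRAY -> RGB_ALPHA: the gray channel is repeated three times and a constant
    # alpha channel at the dtype's max value (1.0 for float) is appended.
    rgba = convert_image_color_space_tensor(
        gray, old_color_space=ColorSpace.GRAY, new_color_space=ColorSpace.RGB_ALPHA
    )
    assert rgba.shape == (4, 16, 16)

    # RGB_ALPHA -> GRAY: the constant alpha channel is stripped (non-constant alpha
    # would raise in _strip_alpha) and the RGB channels are collapsed to one.
    back = convert_image_color_space_tensor(
        rgba, old_color_space=ColorSpace.RGB_ALPHA, new_color_space=ColorSpace.GRAY
    )
    assert back.shape == (1, 16, 16)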