Unverified commit 5d8d61ac, authored by Philip Meier and committed by GitHub

add PermuteChannels transform (#7624)

parent 2ab937a0
@@ -155,6 +155,7 @@ Color
    ColorJitter
    v2.ColorJitter
    v2.RandomChannelPermutation
    v2.RandomPhotometricDistort
    Grayscale
    v2.Grayscale
......
@@ -124,6 +124,7 @@ class TestSmoke:
        (transforms.RandomEqualize(p=1.0), None),
        (transforms.RandomGrayscale(p=1.0), None),
        (transforms.RandomInvert(p=1.0), None),
        (transforms.RandomChannelPermutation(), None),
        (transforms.RandomPhotometricDistort(p=1.0), None),
        (transforms.RandomPosterize(bits=4, p=1.0), None),
        (transforms.RandomSolarize(threshold=0.5, p=1.0), None),
......
@@ -2280,3 +2280,61 @@ class TestGetKernel:
        _register_kernel_internal(F.resize, MyDatapoint, datapoint_wrapper=False)(resize_my_datapoint)

        assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint
class TestPermuteChannels:
    _DEFAULT_PERMUTATION = [2, 0, 1]

    @pytest.mark.parametrize(
        ("kernel", "make_input"),
        [
            (F.permute_channels_image_tensor, make_image_tensor),
            # FIXME
            # check_kernel does not support PIL kernel, but it should
            (F.permute_channels_image_tensor, make_image),
            (F.permute_channels_video, make_video),
        ],
    )
    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel(self, kernel, make_input, dtype, device):
        check_kernel(kernel, make_input(dtype=dtype, device=device), permutation=self._DEFAULT_PERMUTATION)

    @pytest.mark.parametrize(
        ("kernel", "make_input"),
        [
            (F.permute_channels_image_tensor, make_image_tensor),
            (F.permute_channels_image_pil, make_image_pil),
            (F.permute_channels_image_tensor, make_image),
            (F.permute_channels_video, make_video),
        ],
    )
    def test_dispatcher(self, kernel, make_input):
        check_dispatcher(F.permute_channels, kernel, make_input(), permutation=self._DEFAULT_PERMUTATION)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.permute_channels_image_tensor, torch.Tensor),
            (F.permute_channels_image_pil, PIL.Image.Image),
            (F.permute_channels_image_tensor, datapoints.Image),
            (F.permute_channels_video, datapoints.Video),
        ],
    )
    def test_dispatcher_signature(self, kernel, input_type):
        check_dispatcher_kernel_signature_match(F.permute_channels, kernel=kernel, input_type=input_type)

    def reference_image_correctness(self, image, permutation):
        channel_images = image.split(1, dim=-3)
        permuted_channel_images = [channel_images[channel_idx] for channel_idx in permutation]
        return datapoints.Image(torch.concat(permuted_channel_images, dim=-3))

    @pytest.mark.parametrize("permutation", [[2, 0, 1], [1, 2, 0], [2, 0, 1], [0, 1, 2]])
    @pytest.mark.parametrize("batch_dims", [(), (2,), (2, 1)])
    def test_image_correctness(self, permutation, batch_dims):
        image = make_image(batch_dims=batch_dims)

        actual = F.permute_channels(image, permutation=permutation)
        expected = self.reference_image_correctness(image, permutation=permutation)

        torch.testing.assert_close(actual, expected)
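As a quick sanity check for reviewers (not part of the diff), the split/concat reference used in `reference_image_correctness` agrees with plain advanced indexing along the channel dimension:

```python
# Illustrative only, not part of the change: the reference (split channels,
# reorder, concatenate) equals indexing the channel dim with the permutation.
import torch

x = torch.rand(3, 4, 4)
permutation = [2, 0, 1]
reference = torch.cat([x.split(1, dim=-3)[c] for c in permutation], dim=-3)
assert torch.equal(reference, x[permutation, :, :])
```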
@@ -11,6 +11,7 @@ from ._color import (
    Grayscale,
    RandomAdjustSharpness,
    RandomAutocontrast,
    RandomChannelPermutation,
    RandomEqualize,
    RandomGrayscale,
    RandomInvert,
......
@@ -177,7 +177,27 @@ class ColorJitter(Transform):
        return output
-# TODO: This class seems to be untested
class RandomChannelPermutation(Transform):
    """[BETA] Randomly permute the channels of an image or video.

    .. v2betastatus:: RandomChannelPermutation transform
    """

    _transformed_types = (
        datapoints.Image,
        PIL.Image.Image,
        is_simple_tensor,
        datapoints.Video,
    )

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        num_channels, *_ = query_chw(flat_inputs)
        return dict(permutation=torch.randperm(num_channels))

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.permute_channels(inpt, params["permutation"])
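For orientation, a minimal usage sketch of the new transform (not part of the diff); it assumes the `torchvision.transforms.v2` namespace referenced in the docs change above and a plain CHW uint8 image tensor:

```python
# Illustrative only, not part of the change. RandomChannelPermutation takes no
# arguments and draws a fresh torch.randperm(num_channels) on every call.
import torch
from torchvision.transforms import v2

image = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)  # fake RGB image
transform = v2.RandomChannelPermutation()
shuffled = transform(image)  # same shape and dtype, channels in a random order
```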
class RandomPhotometricDistort(Transform):
    """[BETA] Randomly distorts the image or video as used in `SSD: Single Shot
    MultiBox Detector <https://arxiv.org/abs/1512.02325>`_.

@@ -241,21 +261,6 @@ class RandomPhotometricDistort(Transform):
        params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None
        return params

-    def _permute_channels(
-        self, inpt: Union[datapoints._ImageType, datapoints._VideoType], permutation: torch.Tensor
-    ) -> Union[datapoints._ImageType, datapoints._VideoType]:
-        orig_inpt = inpt
-        if isinstance(orig_inpt, PIL.Image.Image):
-            inpt = F.pil_to_tensor(inpt)
-        # TODO: Find a better fix than as_subclass???
-        output = inpt[..., permutation, :, :].as_subclass(type(inpt))
-        if isinstance(orig_inpt, PIL.Image.Image):
-            output = F.to_image_pil(output)
-        return output

    def _transform(
        self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
    ) -> Union[datapoints._ImageType, datapoints._VideoType]:
@@ -270,7 +275,7 @@ class RandomPhotometricDistort(Transform):
        if params["contrast_factor"] is not None and not params["contrast_before"]:
            inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"])
        if params["channel_permutation"] is not None:
-            inpt = self._permute_channels(inpt, permutation=params["channel_permutation"])
            inpt = F.permute_channels(inpt, permutation=params["channel_permutation"])
        return inpt
......
@@ -62,6 +62,10 @@ from ._color import (
    invert_image_pil,
    invert_image_tensor,
    invert_video,
    permute_channels,
    permute_channels_image_pil,
    permute_channels_image_tensor,
    permute_channels_video,
    posterize,
    posterize_image_pil,
    posterize_image_tensor,
......
-from typing import Union
from typing import List, Union

import PIL.Image
import torch

@@ -10,6 +10,8 @@ from torchvision.transforms._functional_tensor import _max_value
from torchvision.utils import _log_api_usage_once

from ._misc import _num_value_bits, to_dtype_image_tensor
from ._type_conversion import pil_to_tensor, to_image_pil
from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal

@@ -641,3 +643,64 @@ invert_image_pil = _register_kernel_internal(invert, PIL.Image.Image)(_FP.invert

@_register_kernel_internal(invert, datapoints.Video)
def invert_video(video: torch.Tensor) -> torch.Tensor:
    return invert_image_tensor(video)
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def permute_channels(inpt: datapoints._InputTypeJIT, permutation: List[int]) -> datapoints._InputTypeJIT:
    """Permute the channels of the input according to the given permutation.

    This function supports plain :class:`~torch.Tensor`'s, :class:`PIL.Image.Image`'s, and
    :class:`torchvision.datapoints.Image` and :class:`torchvision.datapoints.Video`.

    Example:
        >>> rgb_image = torch.rand(3, 256, 256)
        >>> bgr_image = F.permute_channels(rgb_image, permutation=[2, 1, 0])

    Args:
        permutation (List[int]): Valid permutation of the input channel indices. The index of the element determines the
            channel index in the output and the value determines the channel index in the input. For example,
            ``permutation=[2, 0, 1]``

            - puts ``inpt[..., 2, :, :]`` at ``output[..., 0, :, :]``,
            - puts ``inpt[..., 0, :, :]`` at ``output[..., 1, :, :]``, and
            - puts ``inpt[..., 1, :, :]`` at ``output[..., 2, :, :]``.

    Raises:
        ValueError: If ``len(permutation)`` doesn't match the number of channels in the input.
    """
    if torch.jit.is_scripting():
        return permute_channels_image_tensor(inpt, permutation=permutation)

    _log_api_usage_once(permute_channels)

    kernel = _get_kernel(permute_channels, type(inpt))
    return kernel(inpt, permutation=permutation)
@_register_kernel_internal(permute_channels, torch.Tensor)
@_register_kernel_internal(permute_channels, datapoints.Image)
def permute_channels_image_tensor(image: torch.Tensor, permutation: List[int]) -> torch.Tensor:
    shape = image.shape
    num_channels, height, width = shape[-3:]

    if len(permutation) != num_channels:
        raise ValueError(
            f"Length of permutation does not match number of channels: {len(permutation)} != {num_channels}"
        )

    if image.numel() == 0:
        return image

    # Flatten any leading batch dims, reorder the channel dim, then restore the original shape.
    image = image.reshape(-1, num_channels, height, width)
    image = image[:, permutation, :, :]
    return image.reshape(shape)
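To make the indexing convention above concrete, here is a small worked example (not part of the diff); the only assumption is a 3-channel tensor whose channel c is filled with the value c:

```python
# Illustrative only, not part of the change: output channel i is taken from input
# channel permutation[i], matching image[:, permutation, :, :] in the kernel above.
import torch

image = torch.stack([torch.full((2, 2), float(c)) for c in range(3)])  # channel c holds value c
out = image[[2, 0, 1], :, :]
print(out[:, 0, 0])  # tensor([2., 0., 1.]): output channels come from input channels 2, 0, 1
```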
@_register_kernel_internal(permute_channels, PIL.Image.Image)
def permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) -> PIL.Image.Image:
    # Round-trip through a tensor, since PIL has no native channel-permutation op.
    return to_image_pil(permute_channels_image_tensor(pil_to_tensor(image), permutation=permutation))
@_register_kernel_internal(permute_channels, datapoints.Video)
def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor:
    return permute_channels_image_tensor(video, permutation=permutation)