Unverified Commit 5d8d61ac authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add PermuteChannels transform (#7624)

parent 2ab937a0
......@@ -155,6 +155,7 @@ Color
ColorJitter
v2.ColorJitter
v2.RandomChannelPermutation
v2.RandomPhotometricDistort
Grayscale
v2.Grayscale
......
......@@ -124,6 +124,7 @@ class TestSmoke:
(transforms.RandomEqualize(p=1.0), None),
(transforms.RandomGrayscale(p=1.0), None),
(transforms.RandomInvert(p=1.0), None),
(transforms.RandomChannelPermutation(), None),
(transforms.RandomPhotometricDistort(p=1.0), None),
(transforms.RandomPosterize(bits=4, p=1.0), None),
(transforms.RandomSolarize(threshold=0.5, p=1.0), None),
......
......@@ -2280,3 +2280,61 @@ class TestGetKernel:
_register_kernel_internal(F.resize, MyDatapoint, datapoint_wrapper=False)(resize_my_datapoint)
assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint
class TestPermuteChannels:
    """Tests for ``F.permute_channels`` and the kernels registered for it."""

    # Used with the default inputs of the ``make_*`` helpers, so it has to match their channel count.
    _DEFAULT_PERMUTATION = [2, 0, 1]

    @pytest.mark.parametrize(
        ("kernel", "make_input"),
        [
            (F.permute_channels_image_tensor, make_image_tensor),
            # FIXME
            # check_kernel does not support PIL kernel, but it should
            (F.permute_channels_image_tensor, make_image),
            (F.permute_channels_video, make_video),
        ],
    )
    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel(self, kernel, make_input, dtype, device):
        check_kernel(kernel, make_input(dtype=dtype, device=device), permutation=self._DEFAULT_PERMUTATION)

    @pytest.mark.parametrize(
        ("kernel", "make_input"),
        [
            (F.permute_channels_image_tensor, make_image_tensor),
            (F.permute_channels_image_pil, make_image_pil),
            (F.permute_channels_image_tensor, make_image),
            (F.permute_channels_video, make_video),
        ],
    )
    def test_dispatcher(self, kernel, make_input):
        check_dispatcher(F.permute_channels, kernel, make_input(), permutation=self._DEFAULT_PERMUTATION)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.permute_channels_image_tensor, torch.Tensor),
            (F.permute_channels_image_pil, PIL.Image.Image),
            (F.permute_channels_image_tensor, datapoints.Image),
            (F.permute_channels_video, datapoints.Video),
        ],
    )
    def test_dispatcher_signature(self, kernel, input_type):
        check_dispatcher_kernel_signature_match(F.permute_channels, kernel=kernel, input_type=input_type)

    def reference_image_correctness(self, image, permutation):
        """Build the expected result channel by channel, without fancy indexing.

        Output channel ``i`` is input channel ``permutation[i]``.
        """
        channel_images = image.split(1, dim=-3)
        permuted_channel_images = [channel_images[channel_idx] for channel_idx in permutation]
        return datapoints.Image(torch.concat(permuted_channel_images, dim=-3))

    # Two cyclic shifts, a pure swap, and the identity; every entry is a distinct permutation.
    @pytest.mark.parametrize("permutation", [[2, 0, 1], [1, 2, 0], [2, 1, 0], [0, 1, 2]])
    @pytest.mark.parametrize("batch_dims", [(), (2,), (2, 1)])
    def test_image_correctness(self, permutation, batch_dims):
        image = make_image(batch_dims=batch_dims)

        actual = F.permute_channels(image, permutation=permutation)
        expected = self.reference_image_correctness(image, permutation=permutation)

        torch.testing.assert_close(actual, expected)

    def test_permutation_length_mismatch(self):
        # The default image works with the three-element default permutation above,
        # so a two-element permutation cannot match its channel count.
        with pytest.raises(ValueError, match="Length of permutation does not match number of channels"):
            F.permute_channels(make_image(), permutation=[0, 1])
......@@ -11,6 +11,7 @@ from ._color import (
Grayscale,
RandomAdjustSharpness,
RandomAutocontrast,
RandomChannelPermutation,
RandomEqualize,
RandomGrayscale,
RandomInvert,
......
......@@ -177,7 +177,27 @@ class ColorJitter(Transform):
return output
# TODO: This class seems to be untested
class RandomChannelPermutation(Transform):
    """[BETA] Apply a randomly drawn permutation to the channels of an image or video

    .. v2betastatus:: RandomChannelPermutation transform
    """

    # Input types this transform acts on.
    _transformed_types = (
        datapoints.Image,
        PIL.Image.Image,
        is_simple_tensor,
        datapoints.Video,
    )

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Only the first value returned by query_chw (the channel count) is needed.
        channel_count, *_ = query_chw(flat_inputs)
        return {"permutation": torch.randperm(channel_count)}

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.permute_channels(inpt, permutation=params["permutation"])
class RandomPhotometricDistort(Transform):
"""[BETA] Randomly distorts the image or video as used in `SSD: Single Shot
MultiBox Detector <https://arxiv.org/abs/1512.02325>`_.
......@@ -241,21 +261,6 @@ class RandomPhotometricDistort(Transform):
params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None
return params
def _permute_channels(
    self, inpt: Union[datapoints._ImageType, datapoints._VideoType], permutation: torch.Tensor
) -> Union[datapoints._ImageType, datapoints._VideoType]:
    """Reorder the channel dimension (third from last) of ``inpt`` by ``permutation``.

    PIL images are converted to a tensor first and converted back to PIL afterwards.
    """
    is_pil = isinstance(inpt, PIL.Image.Image)
    tensor = F.pil_to_tensor(inpt) if is_pil else inpt

    # TODO: Find a better fix than as_subclass???
    permuted = tensor[..., permutation, :, :].as_subclass(type(tensor))

    return F.to_image_pil(permuted) if is_pil else permuted
def _transform(
self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
) -> Union[datapoints._ImageType, datapoints._VideoType]:
......@@ -270,7 +275,7 @@ class RandomPhotometricDistort(Transform):
if params["contrast_factor"] is not None and not params["contrast_before"]:
inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"])
if params["channel_permutation"] is not None:
inpt = self._permute_channels(inpt, permutation=params["channel_permutation"])
inpt = F.permute_channels(inpt, permutation=params["channel_permutation"])
return inpt
......
......@@ -62,6 +62,10 @@ from ._color import (
invert_image_pil,
invert_image_tensor,
invert_video,
permute_channels,
permute_channels_image_pil,
permute_channels_image_tensor,
permute_channels_video,
posterize,
posterize_image_pil,
posterize_image_tensor,
......
from typing import Union
from typing import List, Union
import PIL.Image
import torch
......@@ -10,6 +10,8 @@ from torchvision.transforms._functional_tensor import _max_value
from torchvision.utils import _log_api_usage_once
from ._misc import _num_value_bits, to_dtype_image_tensor
from ._type_conversion import pil_to_tensor, to_image_pil
from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
......@@ -641,3 +643,64 @@ invert_image_pil = _register_kernel_internal(invert, PIL.Image.Image)(_FP.invert
@_register_kernel_internal(invert, datapoints.Video)
def invert_video(video: torch.Tensor) -> torch.Tensor:
return invert_image_tensor(video)
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def permute_channels(inpt: datapoints._InputTypeJIT, permutation: List[int]) -> datapoints._InputTypeJIT:
    """Permute the channels of the input according to the given permutation.

    This function supports plain :class:`~torch.Tensor`'s, :class:`PIL.Image.Image`'s, and
    :class:`torchvision.datapoints.Image` and :class:`torchvision.datapoints.Video`.

    Example:
        >>> rgb_image = torch.rand(3, 256, 256)
        >>> bgr_image = F.permute_channels(rgb_image, permutation=[2, 1, 0])

    Args:
        inpt: Input whose channels are permuted.
        permutation (List[int]): Valid permutation of the input channel indices. The index of the element determines the
            channel index in the output and the value determines the channel index in the input. For example,
            ``permutation=[2, 0, 1]``

            - takes ``inpt[..., 2, :, :]`` and puts it at ``output[..., 0, :, :]``,
            - takes ``inpt[..., 0, :, :]`` and puts it at ``output[..., 1, :, :]``, and
            - takes ``inpt[..., 1, :, :]`` and puts it at ``output[..., 2, :, :]``.

    Returns:
        The result of the kernel selected for ``type(inpt)``.

    Raises:
        ValueError: If ``len(permutation)`` doesn't match the number of channels in the input.
    """
    # Under TorchScript the kernel lookup is bypassed and the tensor kernel is called directly.
    if torch.jit.is_scripting():
        return permute_channels_image_tensor(inpt, permutation=permutation)

    _log_api_usage_once(permute_channels)

    kernel = _get_kernel(permute_channels, type(inpt))
    return kernel(inpt, permutation=permutation)
@_register_kernel_internal(permute_channels, torch.Tensor)
@_register_kernel_internal(permute_channels, datapoints.Image)
def permute_channels_image_tensor(image: torch.Tensor, permutation: List[int]) -> torch.Tensor:
    """Reorder the channel dimension (third from last) of ``image``.

    Output channel ``i`` is input channel ``permutation[i]``. Leading batch dimensions are
    flattened for the indexing step and restored afterwards.

    Raises:
        ValueError: If ``len(permutation)`` differs from the number of channels.
    """
    original_shape = image.shape
    num_channels, height, width = original_shape[-3:]

    if len(permutation) != num_channels:
        raise ValueError(
            f"Length of permutation does not match number of channels: {len(permutation)} != {num_channels}"
        )

    # Nothing to reorder in an empty tensor; hand the input back as is.
    if image.numel() == 0:
        return image

    flat_batch = image.reshape(-1, num_channels, height, width)
    return flat_batch[:, permutation, :, :].reshape(original_shape)
@_register_kernel_internal(permute_channels, PIL.Image.Image)
def permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) -> PIL.Image.Image:
    """Permute the channels of a PIL image via the tensor kernel.

    The image is converted to a tensor, permuted by ``permute_channels_image_tensor``, and
    converted back to a PIL image.
    """
    # ``PIL.Image`` is the module; the image class is ``PIL.Image.Image``, hence the return annotation.
    permuted = permute_channels_image_tensor(pil_to_tensor(image), permutation=permutation)
    return to_image_pil(permuted)
@_register_kernel_internal(permute_channels, datapoints.Video)
def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor:
    """Permute the channels of a video by delegating to the image tensor kernel."""
    permuted_video = permute_channels_image_tensor(video, permutation=permutation)
    return permuted_video
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment