Unverified Commit ec1c2a12 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

port RandomPhotoMetricDistort to prototype transforms (#5663)

parent 6db54fb7
......@@ -4,7 +4,7 @@ from ._transform import Transform # usort: skip
from ._augment import RandomErasing, RandomMixup, RandomCutmix
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix
from ._color import ColorJitter
from ._color import ColorJitter, RandomPhotometricDistort
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from ._geometry import (
Resize,
......
......@@ -6,8 +6,9 @@ import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F
from torchvision.transforms import functional as _F
from ._utils import is_simple_tensor
from ._utils import is_simple_tensor, get_image_dimensions, query_image
T = TypeVar("T", features.Image, torch.Tensor, PIL.Image.Image)
......@@ -120,5 +121,70 @@ class ColorJitter(Transform):
for transform in params["image_transforms"]:
input = transform(input)
return input
class _RandomChannelShuffle(Transform):
def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
num_channels, _, _ = get_image_dimensions(image)
return dict(permutation=torch.randperm(num_channels))
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if not (isinstance(input, (features.Image, PIL.Image.Image)) or is_simple_tensor(input)):
return input
image = input
if isinstance(input, PIL.Image.Image):
image = _F.pil_to_tensor(image)
output = image[..., params["permutation"], :, :]
if isinstance(input, features.Image):
output = features.Image.new_like(input, output, color_space=features.ColorSpace.OTHER)
elif isinstance(input, PIL.Image.Image):
output = _F.to_pil_image(output)
return output
class RandomPhotometricDistort(Transform):
def __init__(
self,
contrast: Tuple[float, float] = (0.5, 1.5),
saturation: Tuple[float, float] = (0.5, 1.5),
hue: Tuple[float, float] = (-0.05, 0.05),
brightness: Tuple[float, float] = (0.875, 1.125),
p: float = 0.5,
):
super().__init__()
self._brightness = ColorJitter(brightness=brightness)
self._contrast = ColorJitter(contrast=contrast)
self._hue = ColorJitter(hue=hue)
self._saturation = ColorJitter(saturation=saturation)
self._channel_shuffle = _RandomChannelShuffle()
self.p = p
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(
zip(
["brightness", "contrast1", "saturation", "hue", "contrast2", "channel_shuffle"],
torch.rand(6) < self.p,
),
contrast_before=torch.rand(()) < 0.5,
)
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if params["brightness"]:
input = self._brightness(input)
if params["contrast1"] and params["contrast_before"]:
input = self._contrast(input)
if params["saturation"]:
input = self._saturation(input)
if params["saturation"]:
input = self._saturation(input)
if params["contrast2"] and not params["contrast_before"]:
input = self._contrast(input)
if params["channel_shuffle"]:
input = self._channel_shuffle(input)
return input
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment