Unverified Commit c0ba3ec8 authored by Philip Meier, committed by GitHub

Proto transform cleanup (#6408)

* fix TenCrop

* use dispatchers for RandomPhotometricDistort

* add convert_color_space dispatcher and use it in conversion transforms

* fix convert_color_space naming scheme

* add to_color_space method to Image feature

* remove TODO from BoundingBox.to_format()

* fix test

* fix imports

* fix passthrough

* remove apply_recursively in favor of pytree

* refactor BatchMultiCrop
parent 94960fe1
@@ -200,7 +200,7 @@ class TestSmoke:
@parametrize(
[
(
transforms.ConvertImageColorSpace(color_space=new_color_space, old_color_space=old_color_space),
transforms.ConvertColorSpace(color_space=new_color_space, old_color_space=old_color_space),
itertools.chain.from_iterable(
[
fn(color_spaces=[old_color_space])
@@ -223,7 +223,7 @@ class TestSmoke:
)
]
)
def test_convert_image_color_space(self, transform, input):
def test_convert_color_space(self, transform, input):
transform(input)
......
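A minimal sketch of the renamed transform in use, assuming the prototype API shown in this diff; the `features.Image(..., color_space=...)` constructor call is an assumption, not part of this change:

import torch
from torchvision.prototype import features, transforms

# old_color_space is only consulted for plain tensors, which carry no
# color space metadata of their own (see the dispatcher later in this diff).
transform = transforms.ConvertColorSpace(
    color_space=features.ColorSpace.GRAY,
    old_color_space=features.ColorSpace.RGB,
)

image = features.Image(torch.rand(3, 32, 32), color_space=features.ColorSpace.RGB)
gray = transform(image)  # a features.Image with color_space == GRAY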
@@ -60,17 +60,13 @@ class BoundingBox(_Feature):
)
def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state
# import at runtime to avoid cyclic imports
from torchvision.prototype.transforms.functional import convert_bounding_box_format
from torchvision.prototype.transforms import functional as _F
if isinstance(format, str):
format = BoundingBoxFormat.from_str(format.upper())
return BoundingBox.new_like(
self, convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format
self, _F.convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format
)
def horizontal_flip(self) -> BoundingBox:
......
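For illustration, a hedged sketch of `to_format` as rewired above; the `format=` and `image_size=` constructor keywords are assumptions based on the prototype `BoundingBox` of this era:

import torch
from torchvision.prototype import features

box = features.BoundingBox(
    torch.tensor([10.0, 10.0, 20.0, 20.0]),
    format=features.BoundingBoxFormat.XYXY,
    image_size=(32, 32),
)
# Strings are upper-cased and parsed via BoundingBoxFormat.from_str.
box_xywh = box.to_format("xywh")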
@@ -99,6 +99,20 @@ class Image(_Feature):
else:
return ColorSpace.OTHER
def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) -> Image:
from torchvision.prototype.transforms import functional as _F
if isinstance(color_space, str):
color_space = ColorSpace.from_str(color_space.upper())
return Image.new_like(
self,
_F.convert_color_space_image_tensor(
self, old_color_space=self.color_space, new_color_space=color_space, copy=copy
),
color_space=color_space,
)
def show(self) -> None:
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state
......
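A short sketch of the new `to_color_space` method in action (the `features.Image` constructor call is again an assumption):

import torch
from torchvision.prototype import features

image = features.Image(torch.rand(3, 16, 16), color_space=features.ColorSpace.RGB)
gray = image.to_color_space("gray")  # parsed via ColorSpace.from_str("GRAY")
assert gray.color_space is features.ColorSpace.GRAY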
@@ -33,7 +33,7 @@ from ._geometry import (
ScaleJitter,
TenCrop,
)
from ._meta import ConvertBoundingBoxFormat, ConvertImageColorSpace, ConvertImageDtype
from ._meta import ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype
from ._misc import GaussianBlur, Identity, Lambda, Normalize, ToDtype
from ._type_conversion import DecodeImage, LabelToOneHot, ToImagePIL, ToImageTensor
......
@@ -84,30 +84,6 @@ class ColorJitter(Transform):
return output
class _RandomChannelShuffle(Transform):
def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
num_channels, _, _ = get_image_dimensions(image)
return dict(permutation=torch.randperm(num_channels))
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
return inpt
image = inpt
if isinstance(inpt, PIL.Image.Image):
image = _F.pil_to_tensor(image)
output = image[..., params["permutation"], :, :]
if isinstance(inpt, features.Image):
output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.OTHER)
elif isinstance(inpt, PIL.Image.Image):
output = _F.to_pil_image(output)
return output
class RandomPhotometricDistort(Transform):
def __init__(
self,
@@ -118,35 +94,62 @@ class RandomPhotometricDistort(Transform):
p: float = 0.5,
):
super().__init__()
self._brightness = ColorJitter(brightness=brightness)
self._contrast = ColorJitter(contrast=contrast)
self._hue = ColorJitter(hue=hue)
self._saturation = ColorJitter(saturation=saturation)
self._channel_shuffle = _RandomChannelShuffle()
self.brightness = brightness
self.contrast = contrast
self.hue = hue
self.saturation = saturation
self.p = p
def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
num_channels, _, _ = get_image_dimensions(image)
return dict(
zip(
["brightness", "contrast1", "saturation", "hue", "contrast2", "channel_shuffle"],
["brightness", "contrast1", "saturation", "hue", "contrast2"],
torch.rand(6) < self.p,
),
contrast_before=torch.rand(()) < 0.5,
channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None,
)
def _permute_channels(self, inpt: Any, *, permutation: torch.Tensor) -> Any:
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
return inpt
image = inpt
if isinstance(inpt, PIL.Image.Image):
image = _F.pil_to_tensor(image)
output = image[..., permutation, :, :]
if isinstance(inpt, features.Image):
output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.OTHER)
elif isinstance(inpt, PIL.Image.Image):
output = _F.to_pil_image(output)
return output
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if params["brightness"]:
            inpt = self._brightness(inpt)
            inpt = F.adjust_brightness(
                inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1])
            )
        if params["contrast1"] and params["contrast_before"]:
            inpt = self._contrast(inpt)
            inpt = F.adjust_contrast(
                inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1])
            )
        if params["saturation"]:
            inpt = self._saturation(inpt)
            inpt = F.adjust_saturation(
                inpt, saturation_factor=ColorJitter._generate_value(self.saturation[0], self.saturation[1])
            )
        if params["hue"]:
            inpt = self._hue(inpt)
            inpt = F.adjust_hue(inpt, hue_factor=ColorJitter._generate_value(self.hue[0], self.hue[1]))
        if params["contrast2"] and not params["contrast_before"]:
            inpt = self._contrast(inpt)
            inpt = F.adjust_contrast(
                inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1])
            )
        if params["channel_shuffle"]:
            inpt = self._channel_shuffle(inpt)
        if params["channel_permutation"] is not None:
            inpt = self._permute_channels(inpt, permutation=params["channel_permutation"])
        return inpt
......
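To make the new control flow explicit, a standalone sketch of the parameter sampling above. Note that `torch.rand(6)` draws six gate values while `zip` consumes only five names, so the last draw is silently discarded:

import torch

p, num_channels = 0.5, 3
params = dict(
    zip(
        ["brightness", "contrast1", "saturation", "hue", "contrast2"],
        (torch.rand(6) < p).tolist(),  # the sixth draw is dropped by zip
    ),
    contrast_before=bool(torch.rand(()) < 0.5),
    channel_permutation=torch.randperm(num_channels) if torch.rand(()) < p else None,
)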
@@ -4,13 +4,13 @@ from typing import Any, Dict, Optional
import numpy as np
import PIL.Image
import torch
import torchvision.prototype.transforms.functional as F
from torchvision.prototype import features
from torchvision.prototype.features import ColorSpace
from torchvision.prototype.transforms import Transform
from torchvision.transforms import functional as _F
from typing_extensions import Literal
from ._meta import ConvertImageColorSpace
from ._transform import _RandomApplyTransform
from ._utils import is_simple_tensor
@@ -90,13 +90,11 @@ class Grayscale(Transform):
super().__init__()
self.num_output_channels = num_output_channels
self._rgb_to_gray = ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)
self._gray_to_rgb = ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
output = self._rgb_to_gray(inpt)
output = F.convert_color_space(inpt, color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB)
if self.num_output_channels == 3:
output = self._gray_to_rgb(output)
output = F.convert_color_space(output, color_space=ColorSpace.RGB, old_color_space=ColorSpace.GRAY)
return output
@@ -115,8 +113,7 @@ class RandomGrayscale(_RandomApplyTransform):
)
super().__init__(p=p)
self._rgb_to_gray = ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)
self._gray_to_rgb = ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._gray_to_rgb(self._rgb_to_gray(inpt))
output = F.convert_color_space(inpt, color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB)
return F.convert_color_space(output, color_space=ColorSpace.RGB, old_color_space=ColorSpace.GRAY)
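With the per-instance `ConvertImageColorSpace` transforms gone, both deprecated transforms reduce to dispatcher calls; a hedged sketch of the equivalent round trip on a plain tensor:

import torch
from torchvision.prototype.features import ColorSpace
from torchvision.prototype.transforms import functional as F

img = torch.rand(3, 8, 8)
gray = F.convert_color_space(img, color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB)
rgb = F.convert_color_space(gray, color_space=ColorSpace.RGB, old_color_space=ColorSpace.GRAY)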
import collections.abc
import math
import numbers
import warnings
@@ -180,9 +179,9 @@ class TenCrop(Transform):
output = F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip)
return MultiCropResult(features.Image.new_like(inpt, o) for o in output)
elif is_simple_tensor(inpt):
return MultiCropResult(F.ten_crop_image_tensor(inpt, self.size))
return MultiCropResult(F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip))
elif isinstance(inpt, PIL.Image.Image):
return MultiCropResult(F.ten_crop_image_pil(inpt, self.size))
return MultiCropResult(F.ten_crop_image_pil(inpt, self.size, vertical_flip=self.vertical_flip))
else:
return inpt
@@ -194,31 +193,19 @@ class TenCrop(Transform):
class BatchMultiCrop(Transform):
    def forward(self, *inputs: Any) -> Any:
        # This is basically the functionality of `torchvision.prototype.utils._internal.apply_recursively` with one
        # significant difference:
        # Since we need multiple images to batch them together, we need to explicitly exclude `MultiCropResult` from
        # the sequence case.
        def apply_recursively(obj: Any) -> Any:
            if isinstance(obj, MultiCropResult):
                crops = obj
                if isinstance(obj[0], PIL.Image.Image):
                    crops = [pil_to_tensor(crop) for crop in crops]  # type: ignore[assignment]
                batch = torch.stack(crops)
                if isinstance(obj[0], features.Image):
                    batch = features.Image.new_like(obj[0], batch)
                return batch
            elif isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
                return [apply_recursively(item) for item in obj]
            elif isinstance(obj, collections.abc.Mapping):
                return {key: apply_recursively(item) for key, item in obj.items()}
            else:
                return obj

        return apply_recursively(inputs if len(inputs) > 1 else inputs[0])

    _transformed_types = (MultiCropResult,)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        crops = inpt
        if isinstance(inpt[0], PIL.Image.Image):
            crops = [pil_to_tensor(crop) for crop in crops]
        batch = torch.stack(crops)
        if isinstance(inpt[0], features.Image):
            batch = features.Image.new_like(inpt[0], batch)
        return batch
def _check_fill_arg(fill: Union[int, float, Sequence[int], Sequence[float]]) -> None:
......
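The refactor works because the prototype `Transform` base class flattens arbitrarily nested inputs with `torch.utils._pytree` and applies `_transform` only to leaves whose type appears in `_transformed_types`. A minimal sketch of that flatten/transform/unflatten round trip, independent of torchvision:

from torch.utils._pytree import tree_flatten, tree_unflatten

sample = {"crops": [1, 2], "label": 3}
leaves, spec = tree_flatten(sample)   # [1, 2, 3] plus the container structure
leaves = [x * 10 for x in leaves]     # stand-in for _transform on each leaf
assert tree_unflatten(leaves, spec) == {"crops": [10, 20], "label": 30}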
from typing import Any, Dict, Optional, Union
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
@@ -39,11 +40,15 @@ class ConvertImageDtype(Transform):
return inpt
class ConvertImageColorSpace(Transform):
class ConvertColorSpace(Transform):
# F.convert_color_space does NOT handle `_Feature`'s in general
_transformed_types = (torch.Tensor, features.Image, PIL.Image.Image)
def __init__(
self,
color_space: Union[str, features.ColorSpace],
old_color_space: Optional[Union[str, features.ColorSpace]] = None,
copy: bool = True,
) -> None:
super().__init__()
@@ -55,23 +60,9 @@ class ConvertImageColorSpace(Transform):
old_color_space = features.ColorSpace.from_str(old_color_space)
self.old_color_space = old_color_space
        self.copy = copy

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if isinstance(inpt, features.Image):
            output = F.convert_image_color_space_tensor(
                inpt, old_color_space=inpt.color_space, new_color_space=self.color_space
            )
            return features.Image.new_like(inpt, output, color_space=self.color_space)
        elif is_simple_tensor(inpt):
            if self.old_color_space is None:
                raise RuntimeError(
                    f"In order to convert simple tensor images, `{type(self).__name__}(...)` "
                    f"needs to be constructed with the `old_color_space=...` parameter."
                )
            return F.convert_image_color_space_tensor(
                inpt, old_color_space=self.old_color_space, new_color_space=self.color_space
            )
        elif isinstance(inpt, PIL.Image.Image):
            return F.convert_image_color_space_pil(inpt, color_space=self.color_space)
        else:
            return inpt

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.convert_color_space(
            inpt, color_space=self.color_space, old_color_space=self.old_color_space, copy=self.copy
        )
from torchvision.transforms import InterpolationMode # usort: skip
from ._meta import (
convert_bounding_box_format,
convert_image_color_space_tensor,
convert_image_color_space_pil,
convert_color_space_image_tensor,
convert_color_space_image_pil,
convert_color_space,
) # usort: skip
from ._augment import erase_image_pil, erase_image_tensor
......
from typing import Optional, Tuple
from typing import Any, Optional, Tuple
import PIL.Image
import torch
from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
from torchvision.prototype.features import BoundingBoxFormat, ColorSpace, Image
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
get_dimensions_image_tensor = _FT.get_dimensions
@@ -91,7 +91,7 @@ def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
_rgb_to_gray = _FT.rgb_to_grayscale
def convert_image_color_space_tensor(
def convert_color_space_image_tensor(
image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace, copy: bool = True
) -> torch.Tensor:
if new_color_space == old_color_space:
@@ -141,7 +141,7 @@ _COLOR_SPACE_TO_PIL_MODE = {
}
def convert_image_color_space_pil(
def convert_color_space_image_pil(
image: PIL.Image.Image, color_space: ColorSpace, copy: bool = True
) -> PIL.Image.Image:
old_mode = image.mode
@@ -154,3 +154,21 @@ def convert_image_color_space_pil(
return image
return image.convert(new_mode)
def convert_color_space(
inpt: Any, *, color_space: ColorSpace, old_color_space: Optional[ColorSpace] = None, copy: bool = True
) -> Any:
if isinstance(inpt, Image):
return inpt.to_color_space(color_space, copy=copy)
elif isinstance(inpt, PIL.Image.Image):
return convert_color_space_image_pil(inpt, color_space, copy=copy)
else:
if old_color_space is None:
raise RuntimeError(
"In order to convert the color space of simple tensor images, "
"the `old_color_space=...` parameter needs to be passed."
)
return convert_color_space_image_tensor(
inpt, old_color_space=old_color_space, new_color_space=color_space, copy=copy
)
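Summarizing the dispatch rules above: `Image` features carry their own color space, PIL images encode it in their mode, and only plain tensors need an explicit `old_color_space`. A short sketch (the `features.Image` constructor call is an assumption):

import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

feat = features.Image(torch.rand(3, 4, 4), color_space=features.ColorSpace.RGB)
pil = PIL.Image.new("RGB", (4, 4))
plain = torch.rand(3, 4, 4)

F.convert_color_space(feat, color_space=features.ColorSpace.GRAY)
F.convert_color_space(pil, color_space=features.ColorSpace.GRAY)
# Omitting old_color_space on the next call would raise the RuntimeError defined above:
F.convert_color_space(plain, color_space=features.ColorSpace.GRAY, old_color_space=features.ColorSpace.RGB)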
@@ -14,7 +14,6 @@ __all__ = [
"add_suggestion",
"fromfile",
"ReadOnlyTensorBuffer",
"apply_recursively",
"query_recursively",
]
@@ -128,17 +127,6 @@ class ReadOnlyTensorBuffer:
return self._memory[slice(cursor, self.seek(offset, whence))].tobytes()
def apply_recursively(fn: Callable, obj: Any) -> Any:
# We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
# "a" == "a"[0][0]...
if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
return [apply_recursively(fn, item) for item in obj]
elif isinstance(obj, collections.abc.Mapping):
return {key: apply_recursively(fn, item) for key, item in obj.items()}
else:
return fn(obj)
def query_recursively(
fn: Callable[[Tuple[Any, ...], Any], Optional[D]], obj: Any, *, id: Tuple[Any, ...] = ()
) -> Iterator[D]:
......
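With `apply_recursively` removed, the same map-over-nested-containers behaviour is available from `torch.utils._pytree.tree_map`, which also treats strings as leaves and therefore avoids the self-referential recursion the deleted comment warned about:

from torch.utils._pytree import tree_map

assert tree_map(lambda x: x + 1, {"a": [1, 2], "b": 3}) == {"a": [2, 3], "b": 4}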