"tests/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9d7c08f95e79a56a68cf101ccd1b3983ee3d2743"
Unverified commit becaba0e authored by Philip Meier, committed by GitHub

Cleanup of prototype transforms (#6492)

* fix passthrough on transforms and add dispatchers for five and ten crop

* Revert "cleanup prototype auto augment transforms (#6463)"

This reverts commit d8025b9a.

* use legacy kernels in deprecated Grayscale and RandomGrayscale transforms

* fix default type for Lambda transform

* fix default type for ToDtype transform

* move simple_tensor to features module

* [skip ci]

* Revert "move simple_tensor to features module"

This reverts commit 7043b6ee3e3b1f6541371a4f2442cfc1fd664e4a.

* cleanup

* reinstate valid AA changes

* address review

* Fix linter
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 79098ad9
import math
import numbers
-from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union

import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.autoaugment import AutoAugmentPolicy
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

-from ._utils import is_simple_tensor, query_chw
+from ._utils import _isinstance, get_chw, is_simple_tensor

K = TypeVar("K")
V = TypeVar("V")
@@ -35,9 +36,31 @@ class _AutoAugmentBase(Transform):
        key = keys[int(torch.randint(len(keys), ()))]
        return key, dct[key]

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        _, height, width = query_chw(sample)
-        return dict(height=height, width=width)
+    def _extract_image(
+        self,
+        sample: Any,
+        unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask),
+    ) -> Tuple[int, Union[PIL.Image.Image, torch.Tensor, features.Image]]:
+        sample_flat, _ = tree_flatten(sample)
+        images = []
+        for id, inpt in enumerate(sample_flat):
+            if _isinstance(inpt, (features.Image, PIL.Image.Image, is_simple_tensor)):
+                images.append((id, inpt))
+            elif isinstance(inpt, unsupported_types):
+                raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
+
+        if not images:
+            raise TypeError("Found no image in the sample.")
+        if len(images) > 1:
+            raise TypeError(
+                f"Auto augment transformations are only properly defined for a single image, but found {len(images)}."
+            )
+        return images[0]
+
+    def _put_into_sample(self, sample: Any, id: int, item: Any) -> Any:
+        sample_flat, spec = tree_flatten(sample)
+        sample_flat[id] = item
+        return tree_unflatten(sample_flat, spec)
    def _apply_image_transform(
        self,
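The new `_extract_image`/`_put_into_sample` pair leans on `torch.utils._pytree` to treat an arbitrarily nested sample as a flat list of leaves. A minimal standalone sketch of that round trip (the sample layout is made up for illustration):

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# A nested sample such as a dataset might yield (hypothetical layout).
sample = {"image": torch.rand(3, 32, 32), "meta": {"label": 7}}

# Flatten into leaves plus a spec that remembers the structure.
leaves, spec = tree_flatten(sample)

# Replace the image leaf, e.g. with an augmented version.
leaves = [leaf.flip(-1) if isinstance(leaf, torch.Tensor) else leaf for leaf in leaves]

# Rebuild the original structure around the modified leaves.
augmented = tree_unflatten(leaves, spec)
assert augmented["meta"]["label"] == 7  # non-image leaves pass through untouched
```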
@@ -242,22 +265,21 @@ class AutoAugment(_AutoAugmentBase):
        else:
            raise ValueError(f"The provided policy {policy} is not recognized.")

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        params = super(AutoAugment, self)._get_params(sample)
-        params["policy"] = self._policies[int(torch.randint(len(self._policies), ()))]
-        return params
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+
+        id, image = self._extract_image(sample)
+        num_channels, height, width = get_chw(image)

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
-            return inpt
+        policy = self._policies[int(torch.randint(len(self._policies), ()))]

-        for transform_id, probability, magnitude_idx in params["policy"]:
+        for transform_id, probability, magnitude_idx in policy:
            if not torch.rand(()) <= probability:
                continue

            magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]
-            magnitudes = magnitudes_fn(10, params["height"], params["width"])
+            magnitudes = magnitudes_fn(10, height, width)
            if magnitudes is not None:
                magnitude = float(magnitudes[magnitude_idx])
                if signed and torch.rand(()) <= 0.5:
@@ -265,11 +287,11 @@ class AutoAugment(_AutoAugmentBase):
                else:
                    magnitude = 0.0

-            inpt = self._apply_image_transform(
-                inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
+            image = self._apply_image_transform(
+                image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
            )

-        return inpt
+        return self._put_into_sample(sample, id, image)

class RandAugment(_AutoAugmentBase):
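With `forward` overridden, the policy is drawn once per call and applied to the single image found in the sample, and the result is spliced back while every other entry passes through. A hedged usage sketch (the sample keys and sizes are illustrative, assuming the prototype namespace of this era):

```python
import torch
from torchvision.prototype import transforms

transform = transforms.AutoAugment()

# Any nested structure containing exactly one image is accepted.
sample = {"image": torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8), "label": 3}
out = transform(sample)

assert out["label"] == 3                    # non-image entries are returned unchanged
assert out["image"].shape == (3, 224, 224)  # the augmented image replaces the original
```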
@@ -315,14 +337,16 @@ class RandAugment(_AutoAugmentBase):
        self.magnitude = magnitude
        self.num_magnitude_bins = num_magnitude_bins

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
-            return inpt
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+
+        id, image = self._extract_image(sample)
+        num_channels, height, width = get_chw(image)

        for _ in range(self.num_ops):
            transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
-            magnitudes = magnitudes_fn(self.num_magnitude_bins, params["height"], params["width"])
+            magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
            if magnitudes is not None:
                magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
                if signed and torch.rand(()) <= 0.5:
@@ -330,11 +354,11 @@ class RandAugment(_AutoAugmentBase):
                else:
                    magnitude = 0.0

-            inpt = self._apply_image_transform(
-                inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
+            image = self._apply_image_transform(
+                image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
            )

-        return inpt
+        return self._put_into_sample(sample, id, image)

class TrivialAugmentWide(_AutoAugmentBase):
@@ -370,13 +394,15 @@ class TrivialAugmentWide(_AutoAugmentBase):
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_magnitude_bins = num_magnitude_bins

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
-            return inpt
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+
+        id, image = self._extract_image(sample)
+        num_channels, height, width = get_chw(image)

        transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

-        magnitudes = magnitudes_fn(self.num_magnitude_bins, params["height"], params["width"])
+        magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
        if magnitudes is not None:
            magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
            if signed and torch.rand(()) <= 0.5:
@@ -384,9 +410,10 @@ class TrivialAugmentWide(_AutoAugmentBase):
        else:
            magnitude = 0.0

-        return self._apply_image_transform(
-            inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
+        image = self._apply_image_transform(
+            image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
        )
+        return self._put_into_sample(sample, id, image)

class AugMix(_AutoAugmentBase):
@@ -438,13 +465,15 @@ class AugMix(_AutoAugmentBase):
        # Must be on a separate method so that we can overwrite it in tests.
        return torch._sample_dirichlet(params)

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features.Image) or is_simple_tensor(inpt):
-            image = inpt
-        elif isinstance(inpt, PIL.Image.Image):
-            image = pil_to_tensor(inpt)
-        else:
-            return inpt
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        id, orig_image = self._extract_image(sample)
+        num_channels, height, width = get_chw(orig_image)
+
+        if isinstance(orig_image, torch.Tensor):
+            image = orig_image
+        else:  # isinstance(inpt, PIL.Image.Image):
+            image = pil_to_tensor(orig_image)

        augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
@@ -470,7 +499,7 @@ class AugMix(_AutoAugmentBase):
            for _ in range(depth):
                transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)

-                magnitudes = magnitudes_fn(self._PARAMETER_MAX, params["height"], params["width"])
+                magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
                if magnitudes is not None:
                    magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
                    if signed and torch.rand(()) <= 0.5:
@@ -484,9 +513,9 @@ class AugMix(_AutoAugmentBase):
            mix.add_(combined_weights[:, i].view(batch_dims) * aug)
        mix = mix.view(orig_dims).to(dtype=image.dtype)

-        if isinstance(inpt, features.Image):
-            mix = features.Image.new_like(inpt, mix)
-        elif isinstance(inpt, PIL.Image.Image):
+        if isinstance(orig_image, features.Image):
+            mix = features.Image.new_like(orig_image, mix)
+        elif isinstance(orig_image, PIL.Image.Image):
            mix = to_pil_image(mix)

-        return mix
+        return self._put_into_sample(sample, id, mix)
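For context, AugMix blends the clean image with several augmented chains using weights drawn via `torch._sample_dirichlet`; the hunks above only rewire where the image comes from and where the mix goes. A simplified numeric sketch of the weighting scheme (no real augmentations, `alpha=1` assumed):

```python
import torch

mixture_width = 3
image = torch.rand(1, 3, 8, 8)

# m splits mass between the clean image and the augmented chains.
m = torch._sample_dirichlet(torch.tensor([[1.0, 1.0]]))  # shape (1, 2)
chain_weights = torch._sample_dirichlet(torch.ones(1, mixture_width)) * m[:, 1:]

mix = m[:, 0].view(-1, 1, 1, 1) * image
for i in range(mixture_width):
    aug = image.flip(-1)  # stand-in for an augmented copy of the image
    mix.add_(chain_weights[:, i].view(-1, 1, 1, 1) * aug)

# The weights form a convex combination per sample.
assert torch.allclose(m[:, 0] + chain_weights.sum(dim=1), torch.ones(1))
```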
@@ -4,15 +4,14 @@ from typing import Any, Dict, Optional
import numpy as np
import PIL.Image
import torch
-import torchvision.prototype.transforms.functional as F
from torchvision.prototype import features
-from torchvision.prototype.features import ColorSpace
from torchvision.prototype.transforms import Transform
from torchvision.transforms import functional as _F
from typing_extensions import Literal

from ._transform import _RandomApplyTransform
-from ._utils import is_simple_tensor
+from ._utils import is_simple_tensor, query_chw
class ToTensor(Transform):
@@ -59,6 +58,8 @@ class ToPILImage(Transform):
class Grayscale(Transform):
+    _transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)
+
    def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
        deprecation_msg = (
            f"The transform `Grayscale(num_output_channels={num_output_channels})` "
@@ -81,13 +82,12 @@
        self.num_output_channels = num_output_channels

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        output = F.convert_color_space(inpt, color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB)
-        if self.num_output_channels == 3:
-            output = F.convert_color_space(inpt, color_space=ColorSpace.RGB, old_color_space=ColorSpace.GRAY)
-        return output
+        return _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)


class RandomGrayscale(_RandomApplyTransform):
+    _transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)
+
    def __init__(self, p: float = 0.1) -> None:
        warnings.warn(
            "The transform `RandomGrayscale(p=...)` is deprecated and will be removed in a future release. "
@@ -103,6 +103,9 @@ class RandomGrayscale(_RandomApplyTransform):
        super().__init__(p=p)

+    def _get_params(self, sample: Any) -> Dict[str, Any]:
+        num_input_channels, _, _ = query_chw(sample)
+        return dict(num_input_channels=num_input_channels)
+
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        output = F.convert_color_space(inpt, color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB)
-        return F.convert_color_space(output, color_space=ColorSpace.RGB, old_color_space=ColorSpace.GRAY)
+        return _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
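Both deprecated transforms now delegate to the stable `rgb_to_grayscale` kernel, and `RandomGrayscale` probes the sample with `query_chw` so its output keeps the input channel count. A quick check of the kernel behavior this relies on (shapes illustrative):

```python
import torch
from torchvision.transforms import functional as _F

img = torch.rand(3, 16, 16)

gray1 = _F.rgb_to_grayscale(img, num_output_channels=1)
gray3 = _F.rgb_to_grayscale(img, num_output_channels=3)

assert gray1.shape == (1, 16, 16)
assert gray3.shape == (3, 16, 16)  # channels replicated, matching RandomGrayscale's behavior
```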
@@ -156,22 +156,14 @@
            torch.Size([5])
    """

+    _transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)
+
    def __init__(self, size: Union[int, Sequence[int]]) -> None:
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        # TODO: returning a list is technically BC breaking since FiveCrop returned a tuple before. We switched to a
-        # list here to align it with TenCrop.
-        if isinstance(inpt, features.Image):
-            output = F.five_crop_image_tensor(inpt, self.size)
-            return tuple(features.Image.new_like(inpt, o) for o in output)
-        elif is_simple_tensor(inpt):
-            return F.five_crop_image_tensor(inpt, self.size)
-        elif isinstance(inpt, PIL.Image.Image):
-            return F.five_crop_image_pil(inpt, self.size)
-        else:
-            return inpt
+        return F.five_crop(inpt, self.size)

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
@@ -185,21 +177,15 @@ class TenCrop(Transform):
    See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
    """

+    _transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)
+
    def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
        self.vertical_flip = vertical_flip

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features.Image):
-            output = F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip)
-            return [features.Image.new_like(inpt, o) for o in output]
-        elif is_simple_tensor(inpt):
-            return F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip)
-        elif isinstance(inpt, PIL.Image.Image):
-            return F.ten_crop_image_pil(inpt, self.size, vertical_flip=self.vertical_flip)
-        else:
-            return inpt
+        return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
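With the per-type branching moved into the new `F.five_crop`/`F.ten_crop` dispatchers, both `_transform` bodies reduce to one-liners. A hedged usage sketch of the transform on a plain tensor (sizes illustrative, assuming the prototype namespace):

```python
import torch
from torchvision.prototype import transforms

five = transforms.FiveCrop(size=8)
crops = five(torch.rand(3, 32, 32))

assert len(crops) == 5
assert all(crop.shape == (3, 8, 8) for crop in crops)
```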
@@ -22,7 +22,7 @@ class Lambda(Transform):
    def __init__(self, fn: Callable[[Any], Any], *types: Type):
        super().__init__()
        self.fn = fn
-        self.types = types
+        self.types = types or (object,)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if type(inpt) in self.types:
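Previously `self.types` defaulted to the empty tuple, so a `Lambda` constructed without explicit types never fired: `type(inpt) in ()` is always false. The `(object,)` default makes the fallback explicit. A minimal stand-in class (not the prototype import) showing the exact-type dispatch:

```python
from typing import Any, Callable, Tuple, Type

class Lambda:
    """Simplified stand-in for the prototype transform, for illustration only."""

    def __init__(self, fn: Callable[[Any], Any], *types: Type):
        self.fn = fn
        self.types: Tuple[Type, ...] = types or (object,)

    def __call__(self, inpt: Any) -> Any:
        # Exact-type membership, mirroring the `_transform` check above.
        return self.fn(inpt) if type(inpt) in self.types else inpt

double = Lambda(lambda x: x * 2, int)
assert double(3) == 6
assert double("abc") == "abc"  # not a listed type: passed through unchanged
```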
@@ -137,7 +137,7 @@ class GaussianBlur(Transform):
class ToDtype(Lambda):
    def __init__(self, dtype: torch.dtype, *types: Type) -> None:
        self.dtype = dtype
-        super().__init__(functools.partial(torch.Tensor.to, dtype=dtype), *types)
+        super().__init__(functools.partial(torch.Tensor.to, dtype=dtype), *types or (torch.Tensor,))

    def extra_repr(self) -> str:
        return ", ".join([f"dtype={self.dtype}", f"types={[type.__name__ for type in self.types]}"])
@@ -65,6 +65,7 @@ from ._geometry import (
    elastic_image_tensor,
    elastic_segmentation_mask,
    elastic_transform,
+    five_crop,
    five_crop_image_pil,
    five_crop_image_tensor,
    horizontal_flip,
@@ -97,6 +98,7 @@ from ._geometry import (
    rotate_image_pil,
    rotate_image_tensor,
    rotate_segmentation_mask,
+    ten_crop,
    ten_crop_image_pil,
    ten_crop_image_tensor,
    vertical_flip,
@@ -1078,6 +1078,17 @@ def five_crop_image_pil(
    return tl, tr, bl, br, center


+def five_crop(inpt: DType, size: List[int]) -> Tuple[DType, DType, DType, DType, DType]:
+    # TODO: consider breaking BC here to return List[DType] to align this op with `ten_crop`
+    if isinstance(inpt, torch.Tensor):
+        output = five_crop_image_tensor(inpt, size)
+
+        if isinstance(inpt, features.Image):
+            output = tuple(features.Image.new_like(inpt, item) for item in output)  # type: ignore[assignment]
+
+        return output
+    else:  # isinstance(inpt, PIL.Image.Image):
+        return five_crop_image_pil(inpt, size)
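The dispatcher selects the kernel by input type and re-wraps `features.Image` outputs with `new_like` so feature metadata survives the crop. A hedged usage sketch on a plain tensor:

```python
import torch
from torchvision.prototype.transforms import functional as F

img = torch.rand(3, 32, 32)
tl, tr, bl, br, center = F.five_crop(img, [8, 8])

assert center.shape == (3, 8, 8)
```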
def ten_crop_image_tensor(img: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
    tl, tr, bl, br, center = five_crop_image_tensor(img, size)
@@ -1102,3 +1113,13 @@ def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: bool
    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_pil(img, size)

    return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
+def ten_crop(inpt: DType, size: List[int], *, vertical_flip: bool = False) -> List[DType]:
+    if isinstance(inpt, torch.Tensor):
+        output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
+
+        if isinstance(inpt, features.Image):
+            output = [features.Image.new_like(inpt, item) for item in output]
+
+        return output
+    else:  # isinstance(inpt, PIL.Image.Image):
+        return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
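`ten_crop` follows the same pattern but returns a list of the five crops plus their flipped counterparts. A usage sketch, again on a plain tensor with illustrative sizes:

```python
import torch
from torchvision.prototype.transforms import functional as F

img = torch.rand(3, 32, 32)
crops = F.ten_crop(img, [8, 8], vertical_flip=True)

assert len(crops) == 10  # five crops plus their vertically flipped counterparts
```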