"...text-generation-inference.git" did not exist on "90184df79c12ec2aa9111248077e237ca2ba9ee9"
Unverified commit becaba0e authored by Philip Meier, committed by GitHub

Cleanup of prototype transforms (#6492)

* fix pass-through on transforms and add dispatchers for five and ten crop

* Revert "cleanup prototype auto augment transforms (#6463)"

This reverts commit d8025b9a.

* use legacy kernels in deprecated Grayscale and RandomGrayscale transforms

* fix default type for Lambda transform

* fix default type for ToDtype transform

* move simple_tensor to features module

* [skip ci]

* Revert "move simple_tensor to features module"

This reverts commit 7043b6ee3e3b1f6541371a4f2442cfc1fd664e4a.

* cleanup

* reinstate valid AA changes

* address review

* Fix linter
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 79098ad9
 import math
 import numbers
-from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union

 import PIL.Image
 import torch
+from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform
 from torchvision.transforms.autoaugment import AutoAugmentPolicy
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

-from ._utils import is_simple_tensor, query_chw
+from ._utils import _isinstance, get_chw, is_simple_tensor

 K = TypeVar("K")
 V = TypeVar("V")
@@ -35,9 +36,31 @@ class _AutoAugmentBase(Transform):
         key = keys[int(torch.randint(len(keys), ()))]
         return key, dct[key]

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        _, height, width = query_chw(sample)
-        return dict(height=height, width=width)
+    def _extract_image(
+        self,
+        sample: Any,
+        unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask),
+    ) -> Tuple[int, Union[PIL.Image.Image, torch.Tensor, features.Image]]:
+        sample_flat, _ = tree_flatten(sample)
+        images = []
+        for id, inpt in enumerate(sample_flat):
+            if _isinstance(inpt, (features.Image, PIL.Image.Image, is_simple_tensor)):
+                images.append((id, inpt))
+            elif isinstance(inpt, unsupported_types):
+                raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
+
+        if not images:
+            raise TypeError("Found no image in the sample.")
+        if len(images) > 1:
+            raise TypeError(
+                f"Auto augment transformations are only properly defined for a single image, but found {len(images)}."
+            )
+        return images[0]
+
+    def _put_into_sample(self, sample: Any, id: int, item: Any) -> Any:
+        sample_flat, spec = tree_flatten(sample)
+        sample_flat[id] = item
+        return tree_unflatten(sample_flat, spec)
     def _apply_image_transform(
         self,
@@ -242,22 +265,21 @@ class AutoAugment(_AutoAugmentBase):
         else:
             raise ValueError(f"The provided policy {policy} is not recognized.")

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        params = super(AutoAugment, self)._get_params(sample)
-        params["policy"] = self._policies[int(torch.randint(len(self._policies), ()))]
-        return params
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
-            return inpt
+        id, image = self._extract_image(sample)
+        num_channels, height, width = get_chw(image)
+
+        policy = self._policies[int(torch.randint(len(self._policies), ()))]

-        for transform_id, probability, magnitude_idx in params["policy"]:
+        for transform_id, probability, magnitude_idx in policy:
             if not torch.rand(()) <= probability:
                 continue

             magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]
-            magnitudes = magnitudes_fn(10, params["height"], params["width"])
+            magnitudes = magnitudes_fn(10, height, width)
             if magnitudes is not None:
                 magnitude = float(magnitudes[magnitude_idx])
                 if signed and torch.rand(()) <= 0.5:
@@ -265,11 +287,11 @@ class AutoAugment(_AutoAugmentBase):
             else:
                 magnitude = 0.0

-            inpt = self._apply_image_transform(
-                inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
+            image = self._apply_image_transform(
+                image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
             )

-        return inpt
+        return self._put_into_sample(sample, id, image)
 class RandAugment(_AutoAugmentBase):
@@ -315,14 +337,16 @@ class RandAugment(_AutoAugmentBase):
         self.magnitude = magnitude
         self.num_magnitude_bins = num_magnitude_bins

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
-            return inpt
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+
+        id, image = self._extract_image(sample)
+        num_channels, height, width = get_chw(image)

         for _ in range(self.num_ops):
             transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
-            magnitudes = magnitudes_fn(self.num_magnitude_bins, params["height"], params["width"])
+            magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
             if magnitudes is not None:
                 magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
                 if signed and torch.rand(()) <= 0.5:
@@ -330,11 +354,11 @@ class RandAugment(_AutoAugmentBase):
             else:
                 magnitude = 0.0

-            inpt = self._apply_image_transform(
-                inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
+            image = self._apply_image_transform(
+                image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
             )

-        return inpt
+        return self._put_into_sample(sample, id, image)
 class TrivialAugmentWide(_AutoAugmentBase):
@@ -370,13 +394,15 @@ class TrivialAugmentWide(_AutoAugmentBase):
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_magnitude_bins = num_magnitude_bins

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
-            return inpt
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+
+        id, image = self._extract_image(sample)
+        num_channels, height, width = get_chw(image)

         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
-        magnitudes = magnitudes_fn(self.num_magnitude_bins, params["height"], params["width"])
+        magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
         if magnitudes is not None:
             magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
             if signed and torch.rand(()) <= 0.5:
@@ -384,9 +410,10 @@ class TrivialAugmentWide(_AutoAugmentBase):
         else:
             magnitude = 0.0

-        return self._apply_image_transform(
-            inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
+        image = self._apply_image_transform(
+            image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
         )
+        return self._put_into_sample(sample, id, image)
 class AugMix(_AutoAugmentBase):
@@ -438,13 +465,15 @@ class AugMix(_AutoAugmentBase):
         # Must be on a separate method so that we can overwrite it in tests.
         return torch._sample_dirichlet(params)

-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features.Image) or is_simple_tensor(inpt):
-            image = inpt
-        elif isinstance(inpt, PIL.Image.Image):
-            image = pil_to_tensor(inpt)
-        else:
-            return inpt
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        id, orig_image = self._extract_image(sample)
+        num_channels, height, width = get_chw(orig_image)
+
+        if isinstance(orig_image, torch.Tensor):
+            image = orig_image
+        else:  # isinstance(orig_image, PIL.Image.Image):
+            image = pil_to_tensor(orig_image)

         augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
@@ -470,7 +499,7 @@ class AugMix(_AutoAugmentBase):
             for _ in range(depth):
                 transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)
-                magnitudes = magnitudes_fn(self._PARAMETER_MAX, params["height"], params["width"])
+                magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
                 if magnitudes is not None:
                     magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
                     if signed and torch.rand(()) <= 0.5:
@@ -484,9 +513,9 @@ class AugMix(_AutoAugmentBase):
             mix.add_(combined_weights[:, i].view(batch_dims) * aug)
         mix = mix.view(orig_dims).to(dtype=image.dtype)

-        if isinstance(inpt, features.Image):
-            mix = features.Image.new_like(inpt, mix)
-        elif isinstance(inpt, PIL.Image.Image):
+        if isinstance(orig_image, features.Image):
+            mix = features.Image.new_like(orig_image, mix)
+        elif isinstance(orig_image, PIL.Image.Image):
             mix = to_pil_image(mix)

-        return mix
+        return self._put_into_sample(sample, id, mix)
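
Note: the new _extract_image / _put_into_sample helpers lean on torch.utils._pytree to locate the single image inside an arbitrarily nested sample and to splice the augmented result back in. A minimal sketch of that round trip, using a hypothetical (image, target) sample:

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

sample = (torch.rand(3, 32, 32), {"label": 7})  # hypothetical sample structure

# tree_flatten turns the nested structure into a flat list plus a spec that
# remembers how to rebuild it.
flat, spec = tree_flatten(sample)
idx = next(i for i, item in enumerate(flat) if isinstance(item, torch.Tensor))

# Transform the extracted image, then put it back in place.
flat[idx] = flat[idx].flip(-1)  # stand-in for the real augmentation pipeline
new_sample = tree_unflatten(flat, spec)

assert new_sample[1] == {"label": 7}  # the rest of the sample is untouched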
@@ -4,15 +4,14 @@ from typing import Any, Dict, Optional

 import numpy as np
 import PIL.Image
 import torch
-import torchvision.prototype.transforms.functional as F
 from torchvision.prototype import features
-from torchvision.prototype.features import ColorSpace
 from torchvision.prototype.transforms import Transform
 from torchvision.transforms import functional as _F
 from typing_extensions import Literal

 from ._transform import _RandomApplyTransform
-from ._utils import is_simple_tensor
+from ._utils import is_simple_tensor, query_chw
 class ToTensor(Transform):
@@ -59,6 +58,8 @@ class ToPILImage(Transform):

 class Grayscale(Transform):
+    _transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)
+
     def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
         deprecation_msg = (
             f"The transform `Grayscale(num_output_channels={num_output_channels})` "
@@ -81,13 +82,12 @@ class Grayscale(Transform):
         self.num_output_channels = num_output_channels

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        output = F.convert_color_space(inpt, color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB)
-        if self.num_output_channels == 3:
-            output = F.convert_color_space(inpt, color_space=ColorSpace.RGB, old_color_space=ColorSpace.GRAY)
-        return output
+        return _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
 class RandomGrayscale(_RandomApplyTransform):
+    _transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)
+
     def __init__(self, p: float = 0.1) -> None:
         warnings.warn(
             "The transform `RandomGrayscale(p=...)` is deprecated and will be removed in a future release. "
@@ -103,6 +103,9 @@ class RandomGrayscale(_RandomApplyTransform):
         super().__init__(p=p)

+    def _get_params(self, sample: Any) -> Dict[str, Any]:
+        num_input_channels, _, _ = query_chw(sample)
+        return dict(num_input_channels=num_input_channels)
+
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        output = F.convert_color_space(inpt, color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB)
-        return F.convert_color_space(output, color_space=ColorSpace.RGB, old_color_space=ColorSpace.GRAY)
+        return _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
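
Note: both deprecated transforms now delegate to the legacy rgb_to_grayscale kernel; RandomGrayscale queries the input's channel count via query_chw so the output keeps the same number of channels. A small sketch of the kernel behavior this relies on:

import torch
from torchvision.transforms import functional as _F

img = torch.rand(3, 8, 8)
gray1 = _F.rgb_to_grayscale(img, num_output_channels=1)  # shape (1, 8, 8)
gray3 = _F.rgb_to_grayscale(img, num_output_channels=3)  # gray values replicated to 3 channels
assert gray3.shape == img.shape
assert torch.equal(gray3[0], gray3[1])  # all channels are identical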
@@ -156,22 +156,14 @@ class FiveCrop(Transform):
         torch.Size([5])
     """

+    _transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)
+
     def __init__(self, size: Union[int, Sequence[int]]) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        # TODO: returning a list is technically BC breaking since FiveCrop returned a tuple before. We switched to a
-        # list here to align it with TenCrop.
-        if isinstance(inpt, features.Image):
-            output = F.five_crop_image_tensor(inpt, self.size)
-            return tuple(features.Image.new_like(inpt, o) for o in output)
-        elif is_simple_tensor(inpt):
-            return F.five_crop_image_tensor(inpt, self.size)
-        elif isinstance(inpt, PIL.Image.Image):
-            return F.five_crop_image_pil(inpt, self.size)
-        else:
-            return inpt
+        return F.five_crop(inpt, self.size)

     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
@@ -185,21 +177,15 @@ class TenCrop(Transform):
     See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
     """

+    _transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)
+
     def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
         self.vertical_flip = vertical_flip

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features.Image):
-            output = F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip)
-            return [features.Image.new_like(inpt, o) for o in output]
-        elif is_simple_tensor(inpt):
-            return F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip)
-        elif isinstance(inpt, PIL.Image.Image):
-            return F.ten_crop_image_pil(inpt, self.size, vertical_flip=self.vertical_flip)
-        else:
-            return inpt
+        return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)

     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
...
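
Note: _transformed_types mixes plain classes with predicate callables such as is_simple_tensor, and the _isinstance helper imported in the auto augment module dispatches on which kind each entry is. A sketch of the assumed semantics (the real helper lives in _utils.py):

from typing import Any, Callable, Tuple, Type, Union

def _isinstance(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
    for type_or_check in types_or_checks:
        # Classes are checked with isinstance; callables are invoked as predicates.
        if isinstance(type_or_check, type):
            if isinstance(obj, type_or_check):
                return True
        elif type_or_check(obj):
            return True
    return False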
@@ -22,7 +22,7 @@ class Lambda(Transform):
     def __init__(self, fn: Callable[[Any], Any], *types: Type):
         super().__init__()
         self.fn = fn
-        self.types = types
+        self.types = types or (object,)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if type(inpt) in self.types:
@@ -137,7 +137,7 @@ class GaussianBlur(Transform):
 class ToDtype(Lambda):
     def __init__(self, dtype: torch.dtype, *types: Type) -> None:
         self.dtype = dtype
-        super().__init__(functools.partial(torch.Tensor.to, dtype=dtype), *types)
+        super().__init__(functools.partial(torch.Tensor.to, dtype=dtype), *types or (torch.Tensor,))

     def extra_repr(self) -> str:
         return ", ".join([f"dtype={self.dtype}", f"types={[type.__name__ for type in self.types]}"])
...
@@ -65,6 +65,7 @@ from ._geometry import (
     elastic_image_tensor,
     elastic_segmentation_mask,
     elastic_transform,
+    five_crop,
     five_crop_image_pil,
     five_crop_image_tensor,
     horizontal_flip,
@@ -97,6 +98,7 @@ from ._geometry import (
     rotate_image_pil,
     rotate_image_tensor,
     rotate_segmentation_mask,
+    ten_crop,
     ten_crop_image_pil,
     ten_crop_image_tensor,
     vertical_flip,
...
@@ -1078,6 +1078,17 @@ def five_crop_image_pil(
     return tl, tr, bl, br, center


+def five_crop(inpt: DType, size: List[int]) -> Tuple[DType, DType, DType, DType, DType]:
+    # TODO: consider breaking BC here to return List[DType] to align this op with `ten_crop`
+    if isinstance(inpt, torch.Tensor):
+        output = five_crop_image_tensor(inpt, size)
+        if isinstance(inpt, features.Image):
+            output = tuple(features.Image.new_like(inpt, item) for item in output)  # type: ignore[assignment]
+        return output
+    else:  # isinstance(inpt, PIL.Image.Image):
+        return five_crop_image_pil(inpt, size)
+
+
 def ten_crop_image_tensor(img: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
     tl, tr, bl, br, center = five_crop_image_tensor(img, size)
@@ -1102,3 +1113,13 @@ def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: boo
     tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_pil(img, size)

     return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
+
+
+def ten_crop(inpt: DType, size: List[int], *, vertical_flip: bool = False) -> List[DType]:
+    if isinstance(inpt, torch.Tensor):
+        output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
+        if isinstance(inpt, features.Image):
+            output = [features.Image.new_like(inpt, item) for item in output]
+        return output
+    else:  # isinstance(inpt, PIL.Image.Image):
+        return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
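
Note: with the new dispatchers a caller no longer branches on the input type; both functions route tensors (including features.Image) and PIL images to the matching kernel. A hedged usage sketch against the prototype namespace as defined in this diff:

import torch
from torchvision.prototype.transforms import functional as F

img = torch.rand(3, 64, 64)
tl, tr, bl, br, center = F.five_crop(img, [32, 32])     # 5-tuple of 32x32 crops
crops = F.ten_crop(img, [32, 32], vertical_flip=True)   # list of 10 crops
assert len(crops) == 10 and crops[0].shape == (3, 32, 32)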