Unverified Commit 64e7460e authored by Philip Meier, committed by GitHub

Prototype transforms cleanup (#5504)



* fix grayscale to RGB for batches

* make unsupported types in auto augment a parameter

* make auto augment kwargs explicit

* add missing error message

* add support for specifying probabilities on RandomChoice

* remove TODO for deprecating p on random transforms

* streamline sample type checking

* address comments

* split image_size into height and width in auto augment
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent b8de0b84
@@ -7,7 +7,7 @@ import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, functional as F
-from ._utils import query_image, get_image_dimensions
+from ._utils import query_image, get_image_dimensions, has_all, has_any


 class RandomErasing(Transform):
@@ -33,7 +33,6 @@ class RandomErasing(Transform):
             raise ValueError("Scale should be between 0 and 1")
         if p < 0 or p > 1:
             raise ValueError("Random erasing probability should be between 0 and 1")
-        # TODO: deprecate p in favor of wrapping the transform in a RandomApply
         self.p = p
         self.scale = scale
         self.ratio = ratio
@@ -88,9 +87,7 @@ class RandomErasing(Transform):
         return dict(zip("ijhwv", (i, j, h, w, v)))

     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
-        elif isinstance(input, features.Image):
+        if isinstance(input, features.Image):
             output = F.erase_image_tensor(input, **params)
             return features.Image.new_like(input, output)
         elif isinstance(input, torch.Tensor):
@@ -99,10 +96,13 @@ class RandomErasing(Transform):
         return input

     def forward(self, *inputs: Any) -> Any:
-        if torch.rand(1) >= self.p:
-            return inputs if len(inputs) > 1 else inputs[0]
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        elif torch.rand(1) >= self.p:
+            return sample

-        return super().forward(*inputs)
+        return super().forward(sample)


 class RandomMixup(Transform):
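Not part of the diff, but to illustrate the behavioral change: the type check now runs in `forward` over the whole sample and *before* the probability draw, so unsupported inputs fail on every call instead of only on the draws where the transform fires. A hedged usage sketch; the feature constructor signatures are assumptions based on the prototype API at the time of this PR:

```python
import torch
from torchvision.prototype import features, transforms

transform = transforms.RandomErasing(p=0.5)

image = features.Image(torch.rand(3, 224, 224))
transform(image)  # fine: erased with probability 0.5, returned as-is otherwise

# assumed BoundingBox constructor, for illustration only
boxes = features.BoundingBox(torch.zeros(1, 4), format="xyxy", image_size=(224, 224))
transform(image, boxes)  # TypeError on every call, not just when the erase fires
```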
@@ -115,9 +115,7 @@ class RandomMixup(Transform):
         return dict(lam=float(self._dist.sample(())))

     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
-        elif isinstance(input, features.Image):
+        if isinstance(input, features.Image):
             output = F.mixup_image_tensor(input, **params)
             return features.Image.new_like(input, output)
         elif isinstance(input, features.OneHotLabel):
@@ -126,6 +124,14 @@ class RandomMixup(Transform):
         else:
             return input

+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        elif not has_all(sample, features.Image, features.OneHotLabel):
+            raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
+        return super().forward(sample)
+

 class RandomCutmix(Transform):
     def __init__(self, *, alpha: float) -> None:
@@ -157,9 +163,7 @@ class RandomCutmix(Transform):
         return dict(box=box, lam_adjusted=lam_adjusted)

     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
-        elif isinstance(input, features.Image):
+        if isinstance(input, features.Image):
             output = F.cutmix_image_tensor(input, box=params["box"])
             return features.Image.new_like(input, output)
         elif isinstance(input, features.OneHotLabel):
@@ -167,3 +171,11 @@ class RandomCutmix(Transform):
             return features.OneHotLabel.new_like(input, output)
         else:
             return input
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        elif not has_all(sample, features.Image, features.OneHotLabel):
+            raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
+        return super().forward(sample)
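Outside the diff again: both batch transforms now assert their full contract up front, since mixing only the images or only the labels would silently corrupt the sample. A sketch of the expected sample shape; the `alpha` keyword and the feature constructors are assumptions based on the surrounding code:

```python
import torch
from torchvision.prototype import features, transforms

images = features.Image(torch.rand(4, 3, 224, 224))                    # a batch
labels = features.OneHotLabel(torch.eye(10)[torch.randint(10, (4,))])  # one-hot, same batch size

mixup = transforms.RandomMixup(alpha=0.2)
mixed_images, mixed_labels = mixup(images, labels)  # both parts get mixed with the same lam

mixup(images)  # TypeError: only defined for Image's *and* OneHotLabel's
```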
......
 import math
-from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar, Union
+from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar, Union, Type

 import PIL.Image
 import torch
@@ -39,21 +39,20 @@ class _AutoAugmentBase(Transform):
         key = keys[int(torch.randint(len(keys), ()))]
         return key, dct[key]

-    def _check_unsupported(self, input: Any) -> None:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
-
     def _extract_image(
-        self, sample: Any
+        self,
+        sample: Any,
+        unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask),
     ) -> Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]:
         def fn(
             id: Tuple[Any, ...], input: Any
         ) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]:
             if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image):
                 return id, input
-
-            self._check_unsupported(input)
-            return None
+            elif isinstance(input, unsupported_types):
+                raise TypeError(f"Inputs of type {type(input).__name__} are not supported by {type(self).__name__}()")
+            else:
+                return None

         images = list(query_recursively(fn, sample))
         if not images:
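The new `unsupported_types` parameter lets a subclass relax or tighten the rejection list without re-implementing the traversal. A hypothetical subclass as a sketch, using only names from this module:

```python
# Hypothetical: an auto-augment variant that tolerates segmentation masks
# by letting them pass through untouched instead of raising.
class MaskTolerantAugment(_AutoAugmentBase):
    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        # Only bounding boxes are rejected now; masks are simply not extracted.
        id, image = self._extract_image(sample, unsupported_types=(features.BoundingBox,))
        ...
```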
@@ -200,29 +199,40 @@ class _AutoAugmentBase(Transform):

 class AutoAugment(_AutoAugmentBase):
     _AUGMENTATION_SPACE = {
-        "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
-        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
-        "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
-        "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
+        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
+        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
+        "TranslateX": (
+            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
+            True,
+        ),
+        "TranslateY": (
+            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
+            True,
+        ),
+        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
+        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
         "Posterize": (
-            lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
+            lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
             .round()
             .int(),
             False,
         ),
-        "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
-        "AutoContrast": (lambda num_bins, image_size: None, False),
-        "Equalize": (lambda num_bins, image_size: None, False),
-        "Invert": (lambda num_bins, image_size: None, False),
+        "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
+        "AutoContrast": (lambda num_bins, height, width: None, False),
+        "Equalize": (lambda num_bins, height, width: None, False),
+        "Invert": (lambda num_bins, height, width: None, False),
     }

-    def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, **kwargs: Any) -> None:
-        super().__init__(**kwargs)
+    def __init__(
+        self,
+        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
+        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        fill: Optional[List[float]] = None,
+    ) -> None:
+        super().__init__(interpolation=interpolation, fill=fill)
         self.policy = policy
         self._policies = self._get_policies(policy)
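Replacing `**kwargs: Any` with explicit `interpolation`/`fill` makes the signature introspectable and the per-transform defaults visible at the subclass. Construction now reads as below, assuming `AutoAugmentPolicy` and `InterpolationMode` are importable from the prototype namespace, as the surrounding imports suggest:

```python
from torchvision.prototype.transforms import AutoAugment, AutoAugmentPolicy, InterpolationMode

aa = AutoAugment(
    policy=AutoAugmentPolicy.IMAGENET,
    interpolation=InterpolationMode.BILINEAR,  # the visible default is NEAREST
    fill=[128.0, 128.0, 128.0],
)
```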
@@ -331,7 +341,7 @@ class AutoAugment(_AutoAugmentBase):
             magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]

-            magnitudes = magnitudes_fn(10, (height, width))
+            magnitudes = magnitudes_fn(10, height, width)
             if magnitudes is not None:
                 magnitude = float(magnitudes[magnitude_idx])
                 if signed and torch.rand(()) <= 0.5:
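Each `_AUGMENTATION_SPACE` value is a `(magnitudes_fn, signed)` pair; after this change the magnitude functions take `height` and `width` as separate arguments rather than an `image_size` tuple, which removes the easy-to-confuse `image_size[0]`/`image_size[1]` indexing. An illustration, poking at the private class attribute purely for demonstration:

```python
import torch
from torchvision.prototype.transforms import AutoAugment

magnitudes_fn, signed = AutoAugment._AUGMENTATION_SPACE["TranslateX"]
magnitudes = magnitudes_fn(10, 224, 320)  # num_bins, height, width
assert signed
assert magnitudes.shape == (10,)
# TranslateX scales with the width (320), not the height
assert torch.isclose(magnitudes[-1], torch.tensor(150.0 / 331.0 * 320))
```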
@@ -348,29 +358,43 @@

 class RandAugment(_AutoAugmentBase):
     _AUGMENTATION_SPACE = {
-        "Identity": (lambda num_bins, image_size: None, False),
-        "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
-        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
-        "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
-        "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
+        "Identity": (lambda num_bins, height, width: None, False),
+        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
+        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
+        "TranslateX": (
+            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
+            True,
+        ),
+        "TranslateY": (
+            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
+            True,
+        ),
+        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
+        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
         "Posterize": (
-            lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
+            lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
             .round()
             .int(),
             False,
         ),
-        "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
-        "AutoContrast": (lambda num_bins, image_size: None, False),
-        "Equalize": (lambda num_bins, image_size: None, False),
+        "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
+        "AutoContrast": (lambda num_bins, height, width: None, False),
+        "Equalize": (lambda num_bins, height, width: None, False),
     }

-    def __init__(self, *, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31, **kwargs: Any) -> None:
-        super().__init__(**kwargs)
+    def __init__(
+        self,
+        *,
+        num_ops: int = 2,
+        magnitude: int = 9,
+        num_magnitude_bins: int = 31,
+        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        fill: Optional[List[float]] = None,
+    ) -> None:
+        super().__init__(interpolation=interpolation, fill=fill)
         self.num_ops = num_ops
         self.magnitude = magnitude
         self.num_magnitude_bins = num_magnitude_bins
@@ -385,7 +409,7 @@ class RandAugment(_AutoAugmentBase):
         for _ in range(self.num_ops):
             transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

-            magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width))
+            magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
             if magnitudes is not None:
                 magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
                 if signed and torch.rand(()) <= 0.5:
@@ -402,29 +426,35 @@

 class TrivialAugmentWide(_AutoAugmentBase):
     _AUGMENTATION_SPACE = {
-        "Identity": (lambda num_bins, image_size: None, False),
-        "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
-        "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
-        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True),
-        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True),
-        "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 135.0, num_bins), True),
-        "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
-        "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
-        "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
-        "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
+        "Identity": (lambda num_bins, height, width: None, False),
+        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
+        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
+        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
+        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
+        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True),
+        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
+        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
+        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
+        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
         "Posterize": (
-            lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)))
+            lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)))
             .round()
             .int(),
             False,
         ),
-        "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
-        "AutoContrast": (lambda num_bins, image_size: None, False),
-        "Equalize": (lambda num_bins, image_size: None, False),
+        "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
+        "AutoContrast": (lambda num_bins, height, width: None, False),
+        "Equalize": (lambda num_bins, height, width: None, False),
     }

-    def __init__(self, *, num_magnitude_bins: int = 31, **kwargs: Any):
-        super().__init__(**kwargs)
+    def __init__(
+        self,
+        *,
+        num_magnitude_bins: int = 31,
+        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        fill: Optional[List[float]] = None,
+    ):
+        super().__init__(interpolation=interpolation, fill=fill)
         self.num_magnitude_bins = num_magnitude_bins

     def forward(self, *inputs: Any) -> Any:
@@ -436,7 +466,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

-        magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width))
+        magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
         if magnitudes is not None:
             magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
             if signed and torch.rand(()) <= 0.5:
@@ -450,27 +480,27 @@ class TrivialAugmentWide(_AutoAugmentBase):

 class AugMix(_AutoAugmentBase):
     _PARTIAL_AUGMENTATION_SPACE = {
-        "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, image_size[1] / 3.0, num_bins), True),
-        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, image_size[0] / 3.0, num_bins), True),
-        "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
+        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
+        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
+        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, width / 3.0, num_bins), True),
+        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
+        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
         "Posterize": (
-            lambda num_bins, image_size: cast(torch.Tensor, 4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
+            lambda num_bins, height, width: cast(torch.Tensor, 4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
             .round()
             .int(),
             False,
         ),
-        "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
-        "AutoContrast": (lambda num_bins, image_size: None, False),
-        "Equalize": (lambda num_bins, image_size: None, False),
+        "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
+        "AutoContrast": (lambda num_bins, height, width: None, False),
+        "Equalize": (lambda num_bins, height, width: None, False),
     }

-    _AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, Tuple[int, int]], Optional[torch.Tensor]], bool]] = {
+    _AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, int, int], Optional[torch.Tensor]], bool]] = {
         **_PARTIAL_AUGMENTATION_SPACE,
-        "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
-        "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
+        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
+        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
     }

     def __init__(
@@ -480,9 +510,10 @@ class AugMix(_AutoAugmentBase):
         chain_depth: int = -1,
         alpha: float = 1.0,
         all_ops: bool = True,
-        **kwargs: Any,
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        fill: Optional[List[float]] = None,
     ) -> None:
-        super().__init__(**kwargs)
+        super().__init__(interpolation=interpolation, fill=fill)
         self._PARAMETER_MAX = 10
         if not (1 <= severity <= self._PARAMETER_MAX):
             raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
@@ -531,7 +562,7 @@ class AugMix(_AutoAugmentBase):
            for _ in range(depth):
                transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)

-                magnitudes = magnitudes_fn(self._PARAMETER_MAX, (height, width))
+                magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
                 if magnitudes is not None:
                     magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
                     if signed and torch.rand(()) <= 0.5:
......
-from typing import Any
+from typing import Any, Optional, List

 import torch
@@ -37,14 +37,26 @@ class RandomApply(Transform):

 class RandomChoice(Transform):
-    def __init__(self, *transforms: Transform) -> None:
+    def __init__(self, *transforms: Transform, probabilities: Optional[List[float]] = None) -> None:
+        if probabilities is None:
+            probabilities = [1] * len(transforms)
+        elif len(probabilities) != len(transforms):
+            raise ValueError(
+                f"The number of probabilities doesn't match the number of transforms: "
+                f"{len(probabilities)} != {len(transforms)}"
+            )
+
         super().__init__()
         self.transforms = transforms
         for idx, transform in enumerate(transforms):
             self.add_module(str(idx), transform)

+        total = sum(probabilities)
+        self.probabilities = [p / total for p in probabilities]
+
     def forward(self, *inputs: Any) -> Any:
-        idx = int(torch.randint(len(self.transforms), size=()))
+        idx = int(torch.multinomial(torch.tensor(self.probabilities), 1))
         transform = self.transforms[idx]
         return transform(*inputs)
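The probabilities are normalized by their sum, so callers can pass relative weights; `torch.multinomial` then draws the index. A usage sketch; the wrapped transforms are placeholders from elsewhere in this PR:

```python
from torchvision.prototype import transforms

choice = transforms.RandomChoice(
    transforms.HorizontalFlip(),
    transforms.RandomErasing(),
    probabilities=[3, 1],  # normalized internally to [0.75, 0.25]
)
# A mismatched list, e.g. probabilities=[1, 2, 3] for two transforms, raises ValueError.
```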
......
@@ -8,7 +8,7 @@ from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
 from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int

-from ._utils import query_image, get_image_dimensions
+from ._utils import query_image, get_image_dimensions, has_any


 class HorizontalFlip(Transform):
@@ -61,9 +61,7 @@ class CenterCrop(Transform):
         self.output_size = output_size

     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
-        elif isinstance(input, features.Image):
+        if isinstance(input, features.Image):
             output = F.center_crop_image_tensor(input, self.output_size)
             return features.Image.new_like(input, output)
         elif isinstance(input, torch.Tensor):
@@ -73,6 +71,12 @@ class CenterCrop(Transform):
         else:
             return input

+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        return super().forward(sample)
+

 class RandomResizedCrop(Transform):
     def __init__(
@@ -147,9 +151,7 @@ class RandomResizedCrop(Transform):
         return dict(top=i, left=j, height=h, width=w)

     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
-            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
-        elif isinstance(input, features.Image):
+        if isinstance(input, features.Image):
             output = F.resized_crop_image_tensor(
                 input, **params, size=list(self.size), interpolation=self.interpolation
             )
@@ -160,3 +162,9 @@ class RandomResizedCrop(Transform):
             return F.resized_crop_image_pil(input, **params, size=list(self.size), interpolation=self.interpolation)
         else:
             return input
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        return super().forward(sample)
......
@@ -59,7 +59,10 @@ class ConvertImageColorSpace(Transform):
             return features.Image.new_like(input, output, color_space=self.color_space)
         elif isinstance(input, torch.Tensor):
             if self.old_color_space is None:
-                raise RuntimeError("")
+                raise RuntimeError(
+                    f"In order to convert vanilla tensor images, `{type(self).__name__}(...)` "
+                    f"needs to be constructed with the `old_color_space=...` parameter."
+                )
             return F.convert_image_color_space_tensor(
                 input, old_color_space=self.old_color_space, new_color_space=self.color_space
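For context on the new error message: `features.Image` carries its color space as metadata, but a vanilla tensor does not, so the source space must be supplied at construction time. A sketch; the constructor parameter names are inferred from the attributes in this hunk, and the `ColorSpace` enum values are assumptions:

```python
import torch
from torchvision.prototype import features, transforms

convert = transforms.ConvertImageColorSpace(
    color_space=features.ColorSpace.RGB,
    old_color_space=features.ColorSpace.GRAYSCALE,  # required for plain tensors
)
rgb = convert(torch.rand(1, 32, 32))  # without old_color_space this would hit the RuntimeError above
```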
......
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Optional, Tuple, Union, Type, Iterator

 import PIL.Image
 import torch
@@ -34,3 +34,15 @@ def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Im
     else:
         raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}")
     return channels, height, width
+
+
+def _extract_types(sample: Any) -> Iterator[Type]:
+    return query_recursively(lambda id, input: type(input), sample)
+
+
+def has_any(sample: Any, *types: Type) -> bool:
+    return any(issubclass(type, types) for type in _extract_types(sample))
+
+
+def has_all(sample: Any, *types: Type) -> bool:
+    return not bool(set(types) - set(_extract_types(sample)))
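Worth noting the asymmetry between the two helpers: `has_any` matches through `issubclass`, while `has_all` compares the exact types found in the sample via set difference. A behavior sketch, with the sample structure assumed for illustration:

```python
import torch
from torchvision.prototype import features

sample = {
    "image": features.Image(torch.rand(3, 8, 8)),
    "label": features.OneHotLabel(torch.eye(4)[0]),
}

has_any(sample, features.Image, features.BoundingBox)  # True: an Image is present
has_all(sample, features.Image, features.OneHotLabel)  # True: both types occur
has_all(sample, features.Image, features.BoundingBox)  # False: no BoundingBox in the sample
```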
......
@@ -58,7 +58,9 @@ def convert_bounding_box_format(

 def _grayscale_to_rgb_tensor(grayscale: torch.Tensor) -> torch.Tensor:
-    return grayscale.expand(3, 1, 1)
+    repeats = [1] * grayscale.ndim
+    repeats[-3] = 3
+    return grayscale.repeat(repeats)


 def convert_image_color_space_tensor(
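This is the fix the commit title refers to: building the repeat vector from `ndim` keeps the channel axis at position -3, so the same code covers unbatched and batched inputs, which the old `expand(3, 1, 1)` could not. A self-contained check of the new behavior:

```python
import torch

def _grayscale_to_rgb_tensor(grayscale: torch.Tensor) -> torch.Tensor:
    repeats = [1] * grayscale.ndim
    repeats[-3] = 3  # tile only the channel axis, third from the end
    return grayscale.repeat(repeats)

assert _grayscale_to_rgb_tensor(torch.rand(1, 16, 16)).shape == (3, 16, 16)        # single image
assert _grayscale_to_rgb_tensor(torch.rand(8, 1, 16, 16)).shape == (8, 3, 16, 16)  # batch
```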
......