Unverified Commit 64e7460e authored by Philip Meier, committed by GitHub

Prototype transforms cleanup (#5504)



* fix grayscale to RGB for batches

* make unsupported types in auto augment a parameter

* make auto augment kwargs explicit

* add missing error message

* add support for specifying probabilities on RandomChoice

* remove TODO for deprecating p on random transforms

* streamline sample type checking

* address comments

* split image_size into height and width in auto augment
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent b8de0b84
......@@ -7,7 +7,7 @@ import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F
from ._utils import query_image, get_image_dimensions
from ._utils import query_image, get_image_dimensions, has_all, has_any
class RandomErasing(Transform):
......@@ -33,7 +33,6 @@ class RandomErasing(Transform):
raise ValueError("Scale should be between 0 and 1")
if p < 0 or p > 1:
raise ValueError("Random erasing probability should be between 0 and 1")
# TODO: deprecate p in favor of wrapping the transform in a RandomApply
self.p = p
self.scale = scale
self.ratio = ratio
......@@ -88,9 +87,7 @@ class RandomErasing(Transform):
return dict(zip("ijhwv", (i, j, h, w, v)))
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
elif isinstance(input, features.Image):
if isinstance(input, features.Image):
output = F.erase_image_tensor(input, **params)
return features.Image.new_like(input, output)
elif isinstance(input, torch.Tensor):
......@@ -99,10 +96,13 @@ class RandomErasing(Transform):
return input
def forward(self, *inputs: Any) -> Any:
if torch.rand(1) >= self.p:
return inputs if len(inputs) > 1 else inputs[0]
sample = inputs if len(inputs) > 1 else inputs[0]
if has_any(sample, features.BoundingBox, features.SegmentationMask):
raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
elif torch.rand(1) >= self.p:
return sample
return super().forward(*inputs)
return super().forward(sample)
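Note the new ordering in `RandomErasing.forward`: the type check runs before the probability gate, so incompatible samples raise even when the erase would be skipped. A minimal sketch of that behavior, assuming the prototype API at this commit:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import RandomErasing

erase = RandomErasing(p=0.0)  # probability 0: the erase itself never runs
image = features.Image(torch.rand(3, 32, 32))
erase(image)  # fine: returns the image unchanged
mask = features.SegmentationMask(torch.zeros(1, 32, 32))
erase(image, mask)  # still raises TypeError, despite p=0.0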
class RandomMixup(Transform):
......@@ -115,9 +115,7 @@ class RandomMixup(Transform):
return dict(lam=float(self._dist.sample(())))
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
elif isinstance(input, features.Image):
if isinstance(input, features.Image):
output = F.mixup_image_tensor(input, **params)
return features.Image.new_like(input, output)
elif isinstance(input, features.OneHotLabel):
......@@ -126,6 +124,14 @@ class RandomMixup(Transform):
else:
return input
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if has_any(sample, features.BoundingBox, features.SegmentationMask):
raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
elif not has_all(sample, features.Image, features.OneHotLabel):
raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
return super().forward(sample)
class RandomCutmix(Transform):
def __init__(self, *, alpha: float) -> None:
......@@ -157,9 +163,7 @@ class RandomCutmix(Transform):
return dict(box=box, lam_adjusted=lam_adjusted)
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
elif isinstance(input, features.Image):
if isinstance(input, features.Image):
output = F.cutmix_image_tensor(input, box=params["box"])
return features.Image.new_like(input, output)
elif isinstance(input, features.OneHotLabel):
......@@ -167,3 +171,11 @@ class RandomCutmix(Transform):
return features.OneHotLabel.new_like(input, output)
else:
return input
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if has_any(sample, features.BoundingBox, features.SegmentationMask):
raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
elif not has_all(sample, features.Image, features.OneHotLabel):
raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
return super().forward(sample)
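Taken together, `RandomMixup` and `RandomCutmix` now validate the whole sample up front: they reject box/mask inputs and additionally require *both* an `Image` and a `OneHotLabel` to be present. A usage sketch, assuming the prototype API at this commit (shapes are illustrative; both transforms mix within the batch dimension):

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import RandomMixup

images = features.Image(torch.rand(4, 3, 32, 32))  # batch of 4
labels = features.OneHotLabel(torch.rand(4, 10))   # matching one-hot labels

mixup = RandomMixup(alpha=0.2)
mixup(images, labels)  # OK: Image *and* OneHotLabel present
mixup(images)          # raises TypeError: OneHotLabel missing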
import math
from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar, Union
from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar, Union, Type
import PIL.Image
import torch
......@@ -39,21 +39,20 @@ class _AutoAugmentBase(Transform):
key = keys[int(torch.randint(len(keys), ()))]
return key, dct[key]
def _check_unsupported(self, input: Any) -> None:
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
def _extract_image(
self, sample: Any
self,
sample: Any,
unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask),
) -> Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]:
def fn(
id: Tuple[Any, ...], input: Any
) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]:
if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image):
return id, input
self._check_unsupported(input)
return None
elif isinstance(input, unsupported_types):
raise TypeError(f"Inputs of type {type(input).__name__} are not supported by {type(self).__name__}()")
else:
return None
images = list(query_recursively(fn, sample))
if not images:
......@@ -200,29 +199,40 @@ class _AutoAugmentBase(Transform):
class AutoAugment(_AutoAugmentBase):
_AUGMENTATION_SPACE = {
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (
lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
True,
),
"TranslateY": (
lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
True,
),
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Posterize": (
lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
.round()
.int(),
False,
),
"Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, image_size: None, False),
"Equalize": (lambda num_bins, image_size: None, False),
"Invert": (lambda num_bins, image_size: None, False),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, height, width: None, False),
"Equalize": (lambda num_bins, height, width: None, False),
"Invert": (lambda num_bins, height, width: None, False),
}
def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, **kwargs: Any) -> None:
super().__init__(**kwargs)
def __init__(
self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.policy = policy
self._policies = self._get_policies(policy)
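With `**kwargs` replaced by explicit keyword arguments, the accepted options are visible in the signature and typos fail loudly. A hedged sketch of the updated surface (the `_AUGMENTATION_SPACE` access is internal and shown only to illustrate the new `(num_bins, height, width)` call signature):

from torchvision.prototype.transforms import AutoAugment, InterpolationMode
from torchvision.transforms import AutoAugmentPolicy

aa = AutoAugment(
    policy=AutoAugmentPolicy.IMAGENET,
    interpolation=InterpolationMode.NEAREST,
    fill=None,
)

magnitudes_fn, signed = AutoAugment._AUGMENTATION_SPACE["TranslateX"]
bins = magnitudes_fn(10, 224, 224)  # was magnitudes_fn(10, (224, 224))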
......@@ -331,7 +341,7 @@ class AutoAugment(_AutoAugmentBase):
magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]
magnitudes = magnitudes_fn(10, (height, width))
magnitudes = magnitudes_fn(10, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[magnitude_idx])
if signed and torch.rand(()) <= 0.5:
......@@ -348,29 +358,43 @@ class AutoAugment(_AutoAugmentBase):
class RandAugment(_AutoAugmentBase):
_AUGMENTATION_SPACE = {
"Identity": (lambda num_bins, image_size: None, False),
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Identity": (lambda num_bins, height, width: None, False),
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (
lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
True,
),
"TranslateY": (
lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
True,
),
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Posterize": (
lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
.round()
.int(),
False,
),
"Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, image_size: None, False),
"Equalize": (lambda num_bins, image_size: None, False),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, height, width: None, False),
"Equalize": (lambda num_bins, height, width: None, False),
}
def __init__(self, *, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31, **kwargs: Any) -> None:
super().__init__(**kwargs)
def __init__(
self,
*,
num_ops: int = 2,
magnitude: int = 9,
num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.num_ops = num_ops
self.magnitude = magnitude
self.num_magnitude_bins = num_magnitude_bins
......@@ -385,7 +409,7 @@ class RandAugment(_AutoAugmentBase):
for _ in range(self.num_ops):
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width))
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
if signed and torch.rand(()) <= 0.5:
......@@ -402,29 +426,35 @@ class RandAugment(_AutoAugmentBase):
class TrivialAugmentWide(_AutoAugmentBase):
_AUGMENTATION_SPACE = {
"Identity": (lambda num_bins, image_size: None, False),
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True),
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 135.0, num_bins), True),
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
"Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
"Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
"Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
"Identity": (lambda num_bins, height, width: None, False),
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
"TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True),
"Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"Posterize": (
lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)))
lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)))
.round()
.int(),
False,
),
"Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, image_size: None, False),
"Equalize": (lambda num_bins, image_size: None, False),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, height, width: None, False),
"Equalize": (lambda num_bins, height, width: None, False),
}
def __init__(self, *, num_magnitude_bins: int = 31, **kwargs: Any):
super().__init__(**kwargs)
def __init__(
self,
*,
num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
):
super().__init__(interpolation=interpolation, fill=fill)
self.num_magnitude_bins = num_magnitude_bins
def forward(self, *inputs: Any) -> Any:
......@@ -436,7 +466,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width))
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
if signed and torch.rand(()) <= 0.5:
......@@ -450,27 +480,27 @@ class TrivialAugmentWide(_AutoAugmentBase):
class AugMix(_AutoAugmentBase):
_PARTIAL_AUGMENTATION_SPACE = {
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, image_size[1] / 3.0, num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, image_size[0] / 3.0, num_bins), True),
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, width / 3.0, num_bins), True),
"TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
"Posterize": (
lambda num_bins, image_size: cast(torch.Tensor, 4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
lambda num_bins, height, width: cast(torch.Tensor, 4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
.round()
.int(),
False,
),
"Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, image_size: None, False),
"Equalize": (lambda num_bins, image_size: None, False),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, height, width: None, False),
"Equalize": (lambda num_bins, height, width: None, False),
}
_AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, Tuple[int, int]], Optional[torch.Tensor]], bool]] = {
_AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, int, int], Optional[torch.Tensor]], bool]] = {
**_PARTIAL_AUGMENTATION_SPACE,
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
}
def __init__(
......@@ -480,9 +510,10 @@ class AugMix(_AutoAugmentBase):
chain_depth: int = -1,
alpha: float = 1.0,
all_ops: bool = True,
**kwargs: Any,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[List[float]] = None,
) -> None:
super().__init__(**kwargs)
super().__init__(interpolation=interpolation, fill=fill)
self._PARAMETER_MAX = 10
if not (1 <= severity <= self._PARAMETER_MAX):
raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
......@@ -531,7 +562,7 @@ class AugMix(_AutoAugmentBase):
for _ in range(depth):
transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)
magnitudes = magnitudes_fn(self._PARAMETER_MAX, (height, width))
magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
if signed and torch.rand(()) <= 0.5:
......
from typing import Any
from typing import Any, Optional, List
import torch
......@@ -37,14 +37,26 @@ class RandomApply(Transform):
class RandomChoice(Transform):
def __init__(self, *transforms: Transform) -> None:
def __init__(self, *transforms: Transform, probabilities: Optional[List[float]] = None) -> None:
if probabilities is None:
probabilities = [1] * len(transforms)
elif len(probabilities) != len(transforms):
raise ValueError(
f"The number of probabilities doesn't match the number of transforms: "
f"{len(probabilities)} != {len(transforms)}"
)
super().__init__()
self.transforms = transforms
for idx, transform in enumerate(transforms):
self.add_module(str(idx), transform)
total = sum(probabilities)
self.probabilities = [p / total for p in probabilities]
def forward(self, *inputs: Any) -> Any:
idx = int(torch.randint(len(self.transforms), size=()))
idx = int(torch.multinomial(torch.tensor(self.probabilities), 1))
transform = self.transforms[idx]
return transform(*inputs)
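Usage sketch for the new `probabilities` parameter: the weights are normalized by their sum, so any non-negative values work, and omitting the argument keeps the previous uniform behavior.

from torchvision.prototype.transforms import RandomChoice, HorizontalFlip, CenterCrop

transform = RandomChoice(
    HorizontalFlip(),
    CenterCrop(28),
    probabilities=[3, 1],  # normalized internally to [0.75, 0.25]
)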
......
......@@ -8,7 +8,7 @@ from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
from ._utils import query_image, get_image_dimensions
from ._utils import query_image, get_image_dimensions, has_any
class HorizontalFlip(Transform):
......@@ -61,9 +61,7 @@ class CenterCrop(Transform):
self.output_size = output_size
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
elif isinstance(input, features.Image):
if isinstance(input, features.Image):
output = F.center_crop_image_tensor(input, self.output_size)
return features.Image.new_like(input, output)
elif isinstance(input, torch.Tensor):
......@@ -73,6 +71,12 @@ class CenterCrop(Transform):
else:
return input
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if has_any(sample, features.BoundingBox, features.SegmentationMask):
raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
return super().forward(sample)
class RandomResizedCrop(Transform):
def __init__(
......@@ -147,9 +151,7 @@ class RandomResizedCrop(Transform):
return dict(top=i, left=j, height=h, width=w)
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
elif isinstance(input, features.Image):
if isinstance(input, features.Image):
output = F.resized_crop_image_tensor(
input, **params, size=list(self.size), interpolation=self.interpolation
)
......@@ -160,3 +162,9 @@ class RandomResizedCrop(Transform):
return F.resized_crop_image_pil(input, **params, size=list(self.size), interpolation=self.interpolation)
else:
return input
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if has_any(sample, features.BoundingBox, features.SegmentationMask):
raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
return super().forward(sample)
......@@ -59,7 +59,10 @@ class ConvertImageColorSpace(Transform):
return features.Image.new_like(input, output, color_space=self.color_space)
elif isinstance(input, torch.Tensor):
if self.old_color_space is None:
raise RuntimeError("")
raise RuntimeError(
f"In order to convert vanilla tensor images, `{type(self).__name__}(...)` "
f"needs to be constructed with the `old_color_space=...` parameter."
)
return F.convert_image_color_space_tensor(
input, old_color_space=self.old_color_space, new_color_space=self.color_space
......
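Plain tensors carry no color space metadata, which is what the expanded error message now spells out; the caller has to supply it. A hedged sketch (the `ColorSpace` member names are assumptions based on the prototype `features` API):

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import ConvertImageColorSpace

convert = ConvertImageColorSpace(
    color_space=features.ColorSpace.RGB,
    old_color_space=features.ColorSpace.GRAYSCALE,  # required for vanilla tensors
)
rgb = convert(torch.rand(1, 16, 16))  # 1-channel tensor -> 3-channel RGB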
from typing import Any, Optional, Tuple, Union
from typing import Any, Optional, Tuple, Union, Type, Iterator
import PIL.Image
import torch
......@@ -34,3 +34,15 @@ def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Im
else:
raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}")
return channels, height, width
def _extract_types(sample: Any) -> Iterator[Type]:
return query_recursively(lambda id, input: type(input), sample)
def has_any(sample: Any, *types: Type) -> bool:
return any(issubclass(type, types) for type in _extract_types(sample))
def has_all(sample: Any, *types: Type) -> bool:
return not bool(set(types) - set(_extract_types(sample)))
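Behavioral note on the two helpers: `has_any` matches via `issubclass`, so subclasses count, while `has_all` compares exact type sets, so every listed type must literally occur in the sample. A quick sketch:

import torch
from torchvision.prototype import features

sample = (features.Image(torch.rand(3, 8, 8)), torch.tensor([1]))
has_any(sample, features.Image, features.SegmentationMask)  # True: an Image is present
has_all(sample, features.Image, features.SegmentationMask)  # False: no SegmentationMask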
......@@ -58,7 +58,9 @@ def convert_bounding_box_format(
def _grayscale_to_rgb_tensor(grayscale: torch.Tensor) -> torch.Tensor:
return grayscale.expand(3, 1, 1)
repeats = [1] * grayscale.ndim
repeats[-3] = 3
return grayscale.repeat(repeats)
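The fix builds per-dimension repeat counts so the channel axis (third from last in the `(..., C, H, W)` layout) is tripled for unbatched and batched inputs alike; the old `expand(3, 1, 1)` only handled 3-d tensors. A self-contained check:

import torch

def _grayscale_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
    repeats = [1] * grayscale.ndim
    repeats[-3] = 3  # channel dimension
    return grayscale.repeat(repeats)

assert _grayscale_to_rgb(torch.rand(1, 16, 16)).shape == (3, 16, 16)
assert _grayscale_to_rgb(torch.rand(8, 1, 16, 16)).shape == (8, 3, 16, 16)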
def convert_image_color_space_tensor(
......