Unverified Commit e13206d9 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add option to fail a transform on certain types rather than passthrough (#5432)



* add option to fail a transform on certain types rather than passthrough

* address comments
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 2fe0c2d0
...@@ -3,7 +3,9 @@ import numbers ...@@ -3,7 +3,9 @@ import numbers
import warnings import warnings
from typing import Any, Dict, Tuple from typing import Any, Dict, Tuple
import PIL.Image
import torch import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F from torchvision.prototype.transforms import Transform, functional as F
from ._utils import query_image from ._utils import query_image
...@@ -11,6 +13,7 @@ from ._utils import query_image ...@@ -11,6 +13,7 @@ from ._utils import query_image
class RandomErasing(Transform): class RandomErasing(Transform):
_DISPATCHER = F.erase _DISPATCHER = F.erase
_FAIL_TYPES = {PIL.Image.Image, features.BoundingBox, features.SegmentationMask}
def __init__( def __init__(
self, self,
...@@ -98,6 +101,7 @@ class RandomErasing(Transform): ...@@ -98,6 +101,7 @@ class RandomErasing(Transform):
class RandomMixup(Transform): class RandomMixup(Transform):
_DISPATCHER = F.mixup _DISPATCHER = F.mixup
_FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
def __init__(self, *, alpha: float) -> None: def __init__(self, *, alpha: float) -> None:
super().__init__() super().__init__()
...@@ -110,6 +114,7 @@ class RandomMixup(Transform): ...@@ -110,6 +114,7 @@ class RandomMixup(Transform):
class RandomCutmix(Transform): class RandomCutmix(Transform):
_DISPATCHER = F.cutmix _DISPATCHER = F.cutmix
_FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
def __init__(self, *, alpha: float) -> None: def __init__(self, *, alpha: float) -> None:
super().__init__() super().__init__()
......
...@@ -79,9 +79,6 @@ class _AutoAugmentBase(Transform): ...@@ -79,9 +79,6 @@ class _AutoAugmentBase(Transform):
"Invert": lambda input, magnitude, interpolation, fill: F.invert(input), "Invert": lambda input, magnitude, interpolation, fill: F.invert(input),
} }
def _is_supported(self, obj: Any) -> bool:
return type(obj) in {features.Image, torch.Tensor} or isinstance(obj, PIL.Image.Image)
def _get_params(self, sample: Any) -> Dict[str, Any]: def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample) image = query_image(sample)
num_channels = F.get_image_num_channels(image) num_channels = F.get_image_num_channels(image)
...@@ -103,10 +100,12 @@ class _AutoAugmentBase(Transform): ...@@ -103,10 +100,12 @@ class _AutoAugmentBase(Transform):
dispatcher = self._DISPATCHER_MAP[transform_id] dispatcher = self._DISPATCHER_MAP[transform_id]
def transform(input: Any) -> Any: def transform(input: Any) -> Any:
if not self._is_supported(input): if type(input) in {features.Image, torch.Tensor} or isinstance(input, PIL.Image.Image):
return input
return dispatcher(input, magnitude, params["interpolation"], params["fill"]) return dispatcher(input, magnitude, params["interpolation"], params["fill"])
elif type(input) in {features.BoundingBox, features.SegmentationMask}:
raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()")
else:
return input
return apply_recursively(transform, sample) return apply_recursively(transform, sample)
......
...@@ -3,6 +3,7 @@ import warnings ...@@ -3,6 +3,7 @@ import warnings
from typing import Any, Dict, List, Union, Sequence, Tuple, cast from typing import Any, Dict, List, Union, Sequence, Tuple, cast
import torch import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
...@@ -31,6 +32,7 @@ class Resize(Transform): ...@@ -31,6 +32,7 @@ class Resize(Transform):
class CenterCrop(Transform): class CenterCrop(Transform):
_DISPATCHER = F.center_crop _DISPATCHER = F.center_crop
_FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
def __init__(self, output_size: List[int]): def __init__(self, output_size: List[int]):
super().__init__() super().__init__()
...@@ -42,6 +44,7 @@ class CenterCrop(Transform): ...@@ -42,6 +44,7 @@ class CenterCrop(Transform):
class RandomResizedCrop(Transform): class RandomResizedCrop(Transform):
_DISPATCHER = F.resized_crop _DISPATCHER = F.resized_crop
_FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
def __init__( def __init__(
self, self,
......
import enum import enum
import functools import functools
from typing import Any, Dict, Optional from typing import Any, Dict, Optional, Set, Type
from torch import nn from torch import nn
from torchvision.prototype.utils._internal import apply_recursively from torchvision.prototype.utils._internal import apply_recursively
...@@ -11,6 +11,7 @@ from .functional._utils import Dispatcher ...@@ -11,6 +11,7 @@ from .functional._utils import Dispatcher
class Transform(nn.Module): class Transform(nn.Module):
_DISPATCHER: Optional[Dispatcher] = None _DISPATCHER: Optional[Dispatcher] = None
_FAIL_TYPES: Set[Type] = set()
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
...@@ -23,10 +24,12 @@ class Transform(nn.Module): ...@@ -23,10 +24,12 @@ class Transform(nn.Module):
if not self._DISPATCHER: if not self._DISPATCHER:
raise NotImplementedError() raise NotImplementedError()
if input not in self._DISPATCHER: if input in self._DISPATCHER:
return input
return self._DISPATCHER(input, **params) return self._DISPATCHER(input, **params)
elif type(input) in self._FAIL_TYPES:
raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()")
else:
return input
def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0] sample = inputs if len(inputs) > 1 else inputs[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment