Unverified Commit e13206d9 authored by Philip Meier, committed by GitHub

add option to fail a transform on certain types rather than passthrough (#5432)



* add option to fail a transform on certain types rather than passthrough

* address comments
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 2fe0c2d0
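
For readers skimming the diff below: the change replaces silent passthrough with an explicit failure for input types a transform declares it cannot handle. A minimal, self-contained sketch of that policy (illustrative names only, not the torchvision API):

```python
from typing import Any, Callable, Set, Type


def dispatch_sketch(input: Any, kernel: Callable[[Any], Any],
                    handled_types: Set[Type], fail_types: Set[Type]) -> Any:
    # Three-way policy introduced by this commit (names here are made up):
    # handled types are transformed, declared fail types raise, the rest pass through.
    if type(input) in handled_types:
        return kernel(input)
    elif type(input) in fail_types:
        raise TypeError(f"{type(input)} is not supported")
    else:
        return input


assert dispatch_sketch(2, lambda x: x + 1, {int}, {str}) == 3        # handled: transformed
assert dispatch_sketch(None, lambda x: x + 1, {int}, {str}) is None  # unknown: passthrough
# dispatch_sketch("box", lambda x: x + 1, {int}, {str})              # would raise TypeError
```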
......@@ -3,7 +3,9 @@ import numbers
 import warnings
 from typing import Any, Dict, Tuple
 
+import PIL.Image
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, functional as F
 from ._utils import query_image
......@@ -11,6 +13,7 @@ from ._utils import query_image
 class RandomErasing(Transform):
     _DISPATCHER = F.erase
+    _FAIL_TYPES = {PIL.Image.Image, features.BoundingBox, features.SegmentationMask}
 
     def __init__(
         self,
......@@ -98,6 +101,7 @@ class RandomErasing(Transform):
 class RandomMixup(Transform):
     _DISPATCHER = F.mixup
+    _FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
 
     def __init__(self, *, alpha: float) -> None:
         super().__init__()
......@@ -110,6 +114,7 @@ class RandomMixup(Transform):
 class RandomCutmix(Transform):
     _DISPATCHER = F.cutmix
+    _FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
 
     def __init__(self, *, alpha: float) -> None:
         super().__init__()
......
......@@ -79,9 +79,6 @@ class _AutoAugmentBase(Transform):
"Invert": lambda input, magnitude, interpolation, fill: F.invert(input),
}
def _is_supported(self, obj: Any) -> bool:
return type(obj) in {features.Image, torch.Tensor} or isinstance(obj, PIL.Image.Image)
def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
num_channels = F.get_image_num_channels(image)
......@@ -103,10 +100,12 @@ class _AutoAugmentBase(Transform):
        dispatcher = self._DISPATCHER_MAP[transform_id]
 
        def transform(input: Any) -> Any:
-            if not self._is_supported(input):
-                return input
-
-            return dispatcher(input, magnitude, params["interpolation"], params["fill"])
+            if type(input) in {features.Image, torch.Tensor} or isinstance(input, PIL.Image.Image):
+                return dispatcher(input, magnitude, params["interpolation"], params["fill"])
+            elif type(input) in {features.BoundingBox, features.SegmentationMask}:
+                raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()")
+            else:
+                return input
 
        return apply_recursively(transform, sample)
......
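
The AutoAugment hunk above inlines the same policy directly in the transform() closure, since AutoAugment does not go through the base-class dispatcher. A rough stand-alone sketch of that control flow, with a simplified stand-in for apply_recursively (all names below are illustrative, not the prototype API):

```python
from typing import Any, Callable


class FakeImage(int):
    """Stand-in for an image-like input (features.Image / torch.Tensor / PIL image)."""


class FakeBoundingBox:
    """Stand-in for a geometry-bearing input (features.BoundingBox / SegmentationMask)."""


def apply_recursively_sketch(fn: Callable[[Any], Any], obj: Any) -> Any:
    # Simplified stand-in for the recursion helper: walk sequences and dicts,
    # apply fn to every leaf.
    if isinstance(obj, (list, tuple)):
        return type(obj)(apply_recursively_sketch(fn, item) for item in obj)
    if isinstance(obj, dict):
        return {key: apply_recursively_sketch(fn, value) for key, value in obj.items()}
    return fn(obj)


def transform(input: Any) -> Any:
    if isinstance(input, FakeImage):
        return FakeImage(input + 1)  # stand-in for dispatcher(input, magnitude, ...)
    elif isinstance(input, FakeBoundingBox):
        # fail loudly instead of silently returning the input unchanged
        raise TypeError(f"{type(input)} is not supported")
    else:
        return input  # labels and other leaves still pass through


sample = {"image": FakeImage(0), "label": 7}
assert apply_recursively_sketch(transform, sample) == {"image": FakeImage(1), "label": 7}
```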
......@@ -3,6 +3,7 @@ import warnings
 from typing import Any, Dict, List, Union, Sequence, Tuple, cast
 
 import torch
+from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
 from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
......@@ -31,6 +32,7 @@ class Resize(Transform):
 class CenterCrop(Transform):
     _DISPATCHER = F.center_crop
+    _FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
 
     def __init__(self, output_size: List[int]):
         super().__init__()
......@@ -42,6 +44,7 @@ class CenterCrop(Transform):
 class RandomResizedCrop(Transform):
     _DISPATCHER = F.resized_crop
+    _FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
 
     def __init__(
         self,
......
 import enum
 import functools
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Set, Type
 
 from torch import nn
 from torchvision.prototype.utils._internal import apply_recursively
......@@ -11,6 +11,7 @@ from .functional._utils import Dispatcher
 class Transform(nn.Module):
     _DISPATCHER: Optional[Dispatcher] = None
+    _FAIL_TYPES: Set[Type] = set()
 
     def __init__(self) -> None:
         super().__init__()
......@@ -23,10 +24,12 @@ class Transform(nn.Module):
         if not self._DISPATCHER:
             raise NotImplementedError()
 
-        if input not in self._DISPATCHER:
-            return input
-
-        return self._DISPATCHER(input, **params)
+        if input in self._DISPATCHER:
+            return self._DISPATCHER(input, **params)
+        elif type(input) in self._FAIL_TYPES:
+            raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()")
+        else:
+            return input
 
     def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
......
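
With the base class carrying the check, a concrete transform only has to declare _DISPATCHER and _FAIL_TYPES. A hedged, self-contained sketch of that pattern (the dispatcher below is a plain stand-in supporting membership tests and calls, not the prototype Dispatcher):

```python
from typing import Any, Callable, Dict, Optional, Set, Type


class SketchDispatcher:
    # Stand-in for the prototype Dispatcher: membership test + call on supported inputs.
    def __init__(self, kernels: Dict[Type, Callable[..., Any]]) -> None:
        self._kernels = kernels

    def __contains__(self, input: Any) -> bool:
        return type(input) in self._kernels

    def __call__(self, input: Any, **params: Any) -> Any:
        return self._kernels[type(input)](input, **params)


class SketchTransform:
    _DISPATCHER: Optional[SketchDispatcher] = None
    _FAIL_TYPES: Set[Type] = set()

    def _dispatch(self, input: Any, params: Dict[str, Any]) -> Any:
        # Mirrors the branching added to Transform._dispatch above.
        if not self._DISPATCHER:
            raise NotImplementedError()
        if input in self._DISPATCHER:
            return self._DISPATCHER(input, **params)
        elif type(input) in self._FAIL_TYPES:
            raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()")
        else:
            return input


class SketchMixup(SketchTransform):
    # Per-transform configuration only; the dispatch logic stays in the base class.
    _DISPATCHER = SketchDispatcher({float: lambda x, alpha: x * alpha})
    _FAIL_TYPES = {str}  # stand-in for {features.BoundingBox, features.SegmentationMask}


t = SketchMixup()
assert t._dispatch(2.0, {"alpha": 0.5}) == 1.0  # supported type: dispatched
assert t._dispatch(3, {}) == 3                  # unlisted type: passthrough
# t._dispatch("a bounding box", {})             # listed fail type: raises TypeError
```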