Unverified Commit 50608fbc authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix vanilla tensor image detection (#5518)



* fix vanilla tensor image detection

* fix naming
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent e2bb6baa
......@@ -7,7 +7,7 @@ import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F
from ._utils import query_image, get_image_dimensions, has_all, has_any
from ._utils import query_image, get_image_dimensions, has_all, has_any, is_simple_tensor
class RandomErasing(Transform):
......@@ -90,7 +90,7 @@ class RandomErasing(Transform):
if isinstance(input, features.Image):
output = F.erase_image_tensor(input, **params)
return features.Image.new_like(input, output)
elif isinstance(input, torch.Tensor):
elif is_simple_tensor(input):
return F.erase_image_tensor(input, **params)
else:
return input
......
......@@ -8,7 +8,7 @@ from torchvision.prototype.transforms import Transform, InterpolationMode, AutoA
from torchvision.prototype.utils._internal import query_recursively
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from ._utils import get_image_dimensions
from ._utils import get_image_dimensions, is_simple_tensor
K = TypeVar("K")
V = TypeVar("V")
......@@ -89,7 +89,7 @@ class _AutoAugmentBase(Transform):
if isinstance(input, features.Image):
output = image_tensor_kernel(input, *args, **kwargs)
return features.Image.new_like(input, output)
elif isinstance(input, torch.Tensor):
elif is_simple_tensor(input):
return image_tensor_kernel(input, *args, **kwargs)
else: # isinstance(input, PIL.Image.Image):
return image_pil_kernel(input, *args, **kwargs)
......
......@@ -8,7 +8,7 @@ from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
from ._utils import query_image, get_image_dimensions, has_any
from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor
class HorizontalFlip(Transform):
......@@ -21,7 +21,7 @@ class HorizontalFlip(Transform):
return features.BoundingBox.new_like(input, output)
elif isinstance(input, PIL.Image.Image):
return F.horizontal_flip_image_pil(input)
elif isinstance(input, torch.Tensor):
elif is_simple_tensor(input):
return F.horizontal_flip_image_tensor(input)
else:
return input
......@@ -49,7 +49,7 @@ class Resize(Transform):
return features.BoundingBox.new_like(input, output, image_size=cast(Tuple[int, int], tuple(self.size)))
elif isinstance(input, PIL.Image.Image):
return F.resize_image_pil(input, self.size, interpolation=self.interpolation)
elif isinstance(input, torch.Tensor):
elif is_simple_tensor(input):
return F.resize_image_tensor(input, self.size, interpolation=self.interpolation)
else:
return input
......@@ -64,7 +64,7 @@ class CenterCrop(Transform):
if isinstance(input, features.Image):
output = F.center_crop_image_tensor(input, self.output_size)
return features.Image.new_like(input, output)
elif isinstance(input, torch.Tensor):
elif is_simple_tensor(input):
return F.center_crop_image_tensor(input, self.output_size)
elif isinstance(input, PIL.Image.Image):
return F.center_crop_image_pil(input, self.output_size)
......@@ -156,7 +156,7 @@ class RandomResizedCrop(Transform):
input, **params, size=list(self.size), interpolation=self.interpolation
)
return features.Image.new_like(input, output)
elif isinstance(input, torch.Tensor):
elif is_simple_tensor(input):
return F.resized_crop_image_tensor(input, **params, size=list(self.size), interpolation=self.interpolation)
elif isinstance(input, PIL.Image.Image):
return F.resized_crop_image_pil(input, **params, size=list(self.size), interpolation=self.interpolation)
......
......@@ -6,6 +6,8 @@ from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F
from torchvision.transforms.functional import convert_image_dtype
from ._utils import is_simple_tensor
class ConvertBoundingBoxFormat(Transform):
def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None:
......@@ -15,7 +17,7 @@ class ConvertBoundingBoxFormat(Transform):
self.format = format
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if type(input) is features.BoundingBox:
if isinstance(input, features.BoundingBox):
output = F.convert_bounding_box_format(input, old_format=input.format, new_format=params["format"])
return features.BoundingBox.new_like(input, output, format=params["format"])
else:
......@@ -28,9 +30,11 @@ class ConvertImageDtype(Transform):
self.dtype = dtype
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if type(input) is features.Image:
if isinstance(input, features.Image):
output = convert_image_dtype(input, dtype=self.dtype)
return features.Image.new_like(input, output, dtype=self.dtype)
elif is_simple_tensor(input):
return convert_image_dtype(input, dtype=self.dtype)
else:
return input
......@@ -57,7 +61,7 @@ class ConvertImageColorSpace(Transform):
input, old_color_space=input.color_space, new_color_space=self.color_space
)
return features.Image.new_like(input, output, color_space=self.color_space)
elif isinstance(input, torch.Tensor):
elif is_simple_tensor(input):
if self.old_color_space is None:
raise RuntimeError(
f"In order to convert vanilla tensor images, `{type(self).__name__}(...)` "
......
......@@ -6,7 +6,7 @@ from torchvision.prototype.transforms import Transform, functional as F
class DecodeImage(Transform):
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if type(input) is features.EncodedImage:
if isinstance(input, features.EncodedImage):
output = F.decode_image_with_pil(input)
return features.Image(output)
else:
......@@ -19,7 +19,7 @@ class LabelToOneHot(Transform):
self.num_categories = num_categories
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if type(input) is features.Label:
if isinstance(input, features.Label):
num_categories = self.num_categories
if num_categories == -1 and input.categories is not None:
num_categories = len(input.categories)
......
......@@ -46,3 +46,7 @@ def has_any(sample: Any, *types: Type) -> bool:
def has_all(sample: Any, *types: Type) -> bool:
return not bool(set(types) - set(_extract_types(sample)))
def is_simple_tensor(input: Any) -> bool:
return isinstance(input, torch.Tensor) and not isinstance(input, features._Feature)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment