"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "35fda7b4af556e7eeef2b5dcb3638435382b2576"
Unverified commit 50608fbc, authored by Philip Meier and committed by GitHub

fix vanilla tensor image detection (#5518)



* fix vanilla tensor image detection

* fix naming
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent e2bb6baa
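
Background for the fix: the classes in torchvision.prototype.features (Image, BoundingBox, Label, ...) all subclass torch.Tensor, so a plain isinstance(input, torch.Tensor) check in a transform's dispatch also matches feature tensors that should be handled by their own branch or passed through untouched. This commit routes "vanilla" tensor detection through a new is_simple_tensor helper instead. Below is a minimal sketch of the problem; _Feature and BoundingBox here are local stand-ins for the prototype classes, not imports from torchvision:

import torch

# Stand-ins mirroring the prototype hierarchy: features._Feature subclasses
# torch.Tensor, and concrete features like BoundingBox subclass _Feature.
class _Feature(torch.Tensor):
    pass

class BoundingBox(_Feature):
    pass

def is_simple_tensor(input) -> bool:
    return isinstance(input, torch.Tensor) and not isinstance(input, _Feature)

box = torch.zeros(4).as_subclass(BoundingBox)

print(isinstance(box, torch.Tensor))        # True  -> old check treats the box as an image
print(is_simple_tensor(box))                # False -> new check skips feature tensors
print(is_simple_tensor(torch.zeros(3, 4)))  # True  -> vanilla tensors still dispatch as images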
@@ -7,7 +7,7 @@ import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, functional as F
-from ._utils import query_image, get_image_dimensions, has_all, has_any
+from ._utils import query_image, get_image_dimensions, has_all, has_any, is_simple_tensor


 class RandomErasing(Transform):
@@ -90,7 +90,7 @@ class RandomErasing(Transform):
         if isinstance(input, features.Image):
             output = F.erase_image_tensor(input, **params)
             return features.Image.new_like(input, output)
-        elif isinstance(input, torch.Tensor):
+        elif is_simple_tensor(input):
             return F.erase_image_tensor(input, **params)
         else:
             return input
...
@@ -8,7 +8,7 @@ from torchvision.prototype.transforms import Transform, InterpolationMode, AutoA
 from torchvision.prototype.utils._internal import query_recursively
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
-from ._utils import get_image_dimensions
+from ._utils import get_image_dimensions, is_simple_tensor

 K = TypeVar("K")
 V = TypeVar("V")
@@ -89,7 +89,7 @@ class _AutoAugmentBase(Transform):
         if isinstance(input, features.Image):
             output = image_tensor_kernel(input, *args, **kwargs)
             return features.Image.new_like(input, output)
-        elif isinstance(input, torch.Tensor):
+        elif is_simple_tensor(input):
             return image_tensor_kernel(input, *args, **kwargs)
         else:  # isinstance(input, PIL.Image.Image):
             return image_pil_kernel(input, *args, **kwargs)
...
@@ -8,7 +8,7 @@ from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
 from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
-from ._utils import query_image, get_image_dimensions, has_any
+from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor


 class HorizontalFlip(Transform):
@@ -21,7 +21,7 @@ class HorizontalFlip(Transform):
             return features.BoundingBox.new_like(input, output)
         elif isinstance(input, PIL.Image.Image):
             return F.horizontal_flip_image_pil(input)
-        elif isinstance(input, torch.Tensor):
+        elif is_simple_tensor(input):
             return F.horizontal_flip_image_tensor(input)
         else:
             return input
@@ -49,7 +49,7 @@ class Resize(Transform):
             return features.BoundingBox.new_like(input, output, image_size=cast(Tuple[int, int], tuple(self.size)))
         elif isinstance(input, PIL.Image.Image):
             return F.resize_image_pil(input, self.size, interpolation=self.interpolation)
-        elif isinstance(input, torch.Tensor):
+        elif is_simple_tensor(input):
             return F.resize_image_tensor(input, self.size, interpolation=self.interpolation)
         else:
             return input
@@ -64,7 +64,7 @@ class CenterCrop(Transform):
         if isinstance(input, features.Image):
             output = F.center_crop_image_tensor(input, self.output_size)
             return features.Image.new_like(input, output)
-        elif isinstance(input, torch.Tensor):
+        elif is_simple_tensor(input):
             return F.center_crop_image_tensor(input, self.output_size)
         elif isinstance(input, PIL.Image.Image):
             return F.center_crop_image_pil(input, self.output_size)
@@ -156,7 +156,7 @@ class RandomResizedCrop(Transform):
                 input, **params, size=list(self.size), interpolation=self.interpolation
             )
             return features.Image.new_like(input, output)
-        elif isinstance(input, torch.Tensor):
+        elif is_simple_tensor(input):
             return F.resized_crop_image_tensor(input, **params, size=list(self.size), interpolation=self.interpolation)
         elif isinstance(input, PIL.Image.Image):
             return F.resized_crop_image_pil(input, **params, size=list(self.size), interpolation=self.interpolation)
...
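
In the geometry transforms above, the old elif isinstance(input, torch.Tensor) branch could capture non-image feature tensors and run an image kernel on them. A hedged sketch of the fixed behavior, assuming the prototype API as of this commit:

import torch
from torchvision.prototype import features, transforms

label = features.Label(torch.tensor(7))  # subclasses torch.Tensor
resize = transforms.Resize([16, 16])

# With the old check, the label would have been routed into
# F.resize_image_tensor; with is_simple_tensor it falls through to the
# `return input` branch and comes back unchanged.
out = resize(label)
print(type(out).__name__)  # Label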
@@ -6,6 +6,8 @@ from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, functional as F
 from torchvision.transforms.functional import convert_image_dtype

+from ._utils import is_simple_tensor
+

 class ConvertBoundingBoxFormat(Transform):
     def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None:
@@ -15,7 +17,7 @@ class ConvertBoundingBoxFormat(Transform):
         self.format = format

     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if type(input) is features.BoundingBox:
+        if isinstance(input, features.BoundingBox):
             output = F.convert_bounding_box_format(input, old_format=input.format, new_format=params["format"])
             return features.BoundingBox.new_like(input, output, format=params["format"])
         else:
@@ -28,9 +30,11 @@ class ConvertImageDtype(Transform):
         self.dtype = dtype

     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if type(input) is features.Image:
+        if isinstance(input, features.Image):
             output = convert_image_dtype(input, dtype=self.dtype)
             return features.Image.new_like(input, output, dtype=self.dtype)
+        elif is_simple_tensor(input):
+            return convert_image_dtype(input, dtype=self.dtype)
         else:
             return input
@@ -57,7 +61,7 @@ class ConvertImageColorSpace(Transform):
                 input, old_color_space=input.color_space, new_color_space=self.color_space
             )
             return features.Image.new_like(input, output, color_space=self.color_space)
-        elif isinstance(input, torch.Tensor):
+        elif is_simple_tensor(input):
             if self.old_color_space is None:
                 raise RuntimeError(
                     f"In order to convert vanilla tensor images, `{type(self).__name__}(...)` "
...
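
The ConvertImageDtype hunk above is the one behavioral addition in this commit: plain tensors previously fell through to return input and were never converted, while they now take the new is_simple_tensor branch. A sketch, again assuming the prototype API:

import torch
from torchvision.prototype import transforms

convert = transforms.ConvertImageDtype(torch.float32)
img = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)

out = convert(img)
print(out.dtype)  # torch.float32 (before this commit: torch.uint8, unchanged)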
@@ -6,7 +6,7 @@ from torchvision.prototype.transforms import Transform, functional as F

 class DecodeImage(Transform):
     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if type(input) is features.EncodedImage:
+        if isinstance(input, features.EncodedImage):
             output = F.decode_image_with_pil(input)
             return features.Image(output)
         else:
@@ -19,7 +19,7 @@ class LabelToOneHot(Transform):
         self.num_categories = num_categories

     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if type(input) is features.Label:
+        if isinstance(input, features.Label):
             num_categories = self.num_categories
             if num_categories == -1 and input.categories is not None:
                 num_categories = len(input.categories)
...
@@ -46,3 +46,7 @@ def has_any(sample: Any, *types: Type) -> bool:

 def has_all(sample: Any, *types: Type) -> bool:
     return not bool(set(types) - set(_extract_types(sample)))
+
+
+def is_simple_tensor(input: Any) -> bool:
+    return isinstance(input, torch.Tensor) and not isinstance(input, features._Feature)
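
The helper itself just excludes features._Feature subclasses from the torch.Tensor check. Assuming the module paths at this commit, it can be exercised directly:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms._utils import is_simple_tensor

plain = torch.rand(3, 32, 32)
image = features.Image(plain)

print(is_simple_tensor(plain))  # True  -- vanilla tensor, dispatched as an image
print(is_simple_tensor(image))  # False -- handled by the features.Image branch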