Unverified Commit 7de63171 authored by Philip Meier, committed by GitHub
Browse files

move simple_tensor to features module (#6507)

* move simple_tensor to features module

* fix test
parent 13ea9018
......@@ -6,7 +6,7 @@ import torch
from test_prototype_transforms_functional import make_bounding_box, make_image, make_segmentation_mask
from torchvision.prototype import features
from torchvision.prototype.transforms._utils import has_all, has_any, is_simple_tensor
from torchvision.prototype.transforms._utils import has_all, has_any
from torchvision.prototype.transforms.functional import to_image_pil
......@@ -36,9 +36,9 @@ SEGMENTATION_MASK = make_segmentation_mask(size=IMAGE.image_size)
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda obj: isinstance(obj, features.Image),), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True),
((IMAGE,), (features.Image, PIL.Image.Image, is_simple_tensor), True),
((torch.Tensor(IMAGE),), (features.Image, PIL.Image.Image, is_simple_tensor), True),
((to_image_pil(IMAGE),), (features.Image, PIL.Image.Image, is_simple_tensor), True),
((IMAGE,), (features.Image, PIL.Image.Image, features.is_simple_tensor), True),
((torch.Tensor(IMAGE),), (features.Image, PIL.Image.Image, features.is_simple_tensor), True),
((to_image_pil(IMAGE),), (features.Image, PIL.Image.Image, features.is_simple_tensor), True),
],
)
def test_has_any(sample, types, expected):
......
from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._encoded import EncodedData, EncodedImage, EncodedVideo
from ._feature import _Feature
from ._feature import _Feature, is_simple_tensor
from ._image import ColorSpace, Image
from ._label import Label, OneHotLabel
from ._segmentation_mask import SegmentationMask
......@@ -10,6 +10,10 @@ from torchvision.transforms import InterpolationMode
F = TypeVar("F", bound="_Feature")
def is_simple_tensor(inpt: Any) -> bool:
    """Return ``True`` if *inpt* is a plain :class:`torch.Tensor`.

    A "simple" tensor is any tensor that is *not* an instance of the
    ``_Feature`` subclass hierarchy defined in this module.
    """
    if not isinstance(inpt, torch.Tensor):
        return False
    return not isinstance(inpt, _Feature)
class _Feature(torch.Tensor):
__F: Optional[ModuleType] = None
......
......@@ -13,7 +13,7 @@ from torchvision.prototype.transforms import functional as F
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor
from ._transform import _RandomApplyTransform
from ._utils import has_any, is_simple_tensor, query_chw
from ._utils import has_any, query_chw
class RandomErasing(_RandomApplyTransform):
......@@ -102,7 +102,7 @@ class _BaseMixupCutmix(_RandomApplyTransform):
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
def forward(self, *inputs: Any) -> Any:
if not (has_any(inputs, features.Image, is_simple_tensor) and has_any(inputs, features.OneHotLabel)):
if not (has_any(inputs, features.Image, features.is_simple_tensor) and has_any(inputs, features.OneHotLabel)):
raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.")
if has_any(inputs, features.BoundingBox, features.SegmentationMask, features.Label):
raise TypeError(
......@@ -124,7 +124,7 @@ class RandomMixup(_BaseMixupCutmix):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
lam = params["lam"]
if isinstance(inpt, features.Image) or is_simple_tensor(inpt):
if isinstance(inpt, features.Image) or features.is_simple_tensor(inpt):
if inpt.ndim < 4:
raise ValueError("Need a batch of images")
output = inpt.clone()
......@@ -164,7 +164,7 @@ class RandomCutmix(_BaseMixupCutmix):
return dict(box=box, lam_adjusted=lam_adjusted)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, features.Image) or is_simple_tensor(inpt):
if isinstance(inpt, features.Image) or features.is_simple_tensor(inpt):
box = params["box"]
if inpt.ndim < 4:
raise ValueError("Need a batch of images")
......@@ -276,7 +276,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
# with List[image], List[BoundingBox], List[SegmentationMask], List[Label]
images, bboxes, masks, labels = [], [], [], []
for obj in flat_sample:
if isinstance(obj, features.Image) or is_simple_tensor(obj):
if isinstance(obj, features.Image) or features.is_simple_tensor(obj):
images.append(obj)
elif isinstance(obj, PIL.Image.Image):
images.append(pil_to_tensor(obj))
......@@ -310,7 +310,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
elif isinstance(obj, PIL.Image.Image):
flat_sample[i] = F.to_image_pil(output_images[c0])
c0 += 1
elif is_simple_tensor(obj):
elif features.is_simple_tensor(obj):
flat_sample[i] = output_images[c0]
c0 += 1
elif isinstance(obj, features.BoundingBox):
......
......@@ -11,7 +11,7 @@ from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.autoaugment import AutoAugmentPolicy
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
from ._utils import _isinstance, get_chw, is_simple_tensor
from ._utils import _isinstance, get_chw
K = TypeVar("K")
V = TypeVar("V")
......@@ -44,7 +44,7 @@ class _AutoAugmentBase(Transform):
sample_flat, _ = tree_flatten(sample)
images = []
for id, inpt in enumerate(sample_flat):
if _isinstance(inpt, (features.Image, PIL.Image.Image, is_simple_tensor)):
if _isinstance(inpt, (features.Image, PIL.Image.Image, features.is_simple_tensor)):
images.append((id, inpt))
elif isinstance(inpt, unsupported_types):
raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
......
......@@ -8,7 +8,7 @@ from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms import functional as _F
from ._transform import _RandomApplyTransform
from ._utils import is_simple_tensor, query_chw
from ._utils import query_chw
T = TypeVar("T", features.Image, torch.Tensor, PIL.Image.Image)
......@@ -112,7 +112,7 @@ class RandomPhotometricDistort(Transform):
)
def _permute_channels(self, inpt: Any, *, permutation: torch.Tensor) -> Any:
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or features.is_simple_tensor(inpt)):
return inpt
image = inpt
......
......@@ -11,7 +11,7 @@ from torchvision.transforms import functional as _F
from typing_extensions import Literal
from ._transform import _RandomApplyTransform
from ._utils import is_simple_tensor, query_chw
from ._utils import query_chw
class ToTensor(Transform):
......@@ -43,7 +43,7 @@ class PILToTensor(Transform):
class ToPILImage(Transform):
_transformed_types = (is_simple_tensor, features.Image, np.ndarray)
_transformed_types = (features.is_simple_tensor, features.Image, np.ndarray)
def __init__(self, mode: Optional[str] = None) -> None:
warnings.warn(
......@@ -58,7 +58,7 @@ class ToPILImage(Transform):
class Grayscale(Transform):
_transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
deprecation_msg = (
......@@ -86,7 +86,7 @@ class Grayscale(Transform):
class RandomGrayscale(_RandomApplyTransform):
_transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
def __init__(self, p: float = 0.1) -> None:
warnings.warn(
......
......@@ -15,7 +15,7 @@ from torchvision.transforms.transforms import _check_sequence_input, _setup_angl
from typing_extensions import Literal
from ._transform import _RandomApplyTransform
from ._utils import has_all, has_any, is_simple_tensor, query_bounding_box, query_chw
from ._utils import has_all, has_any, query_bounding_box, query_chw
class RandomHorizontalFlip(_RandomApplyTransform):
......@@ -156,7 +156,7 @@ class FiveCrop(Transform):
torch.Size([5])
"""
_transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
def __init__(self, size: Union[int, Sequence[int]]) -> None:
super().__init__()
......@@ -176,7 +176,7 @@ class TenCrop(Transform):
See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
"""
_transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
super().__init__()
......@@ -696,7 +696,7 @@ class RandomIoUCrop(Transform):
def forward(self, *inputs: Any) -> Any:
if not (
has_all(inputs, features.BoundingBox)
and has_any(inputs, PIL.Image.Image, features.Image, is_simple_tensor)
and has_any(inputs, PIL.Image.Image, features.Image, features.is_simple_tensor)
and has_any(inputs, features.Label, features.OneHotLabel)
):
raise TypeError(
......@@ -847,7 +847,7 @@ class FixedSizeCrop(Transform):
return inpt
def forward(self, *inputs: Any) -> Any:
if not has_any(inputs, PIL.Image.Image, features.Image, is_simple_tensor):
if not has_any(inputs, PIL.Image.Image, features.Image, features.is_simple_tensor):
raise TypeError(f"{type(self).__name__}() requires input sample to contain an tensor or PIL image.")
if has_any(inputs, features.BoundingBox) and not has_any(inputs, features.Label, features.OneHotLabel):
......
......@@ -7,8 +7,6 @@ from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.functional import convert_image_dtype
from ._utils import is_simple_tensor
class ConvertBoundingBoxFormat(Transform):
_transformed_types = (features.BoundingBox,)
......@@ -25,7 +23,7 @@ class ConvertBoundingBoxFormat(Transform):
class ConvertImageDtype(Transform):
_transformed_types = (is_simple_tensor, features.Image)
_transformed_types = (features.is_simple_tensor, features.Image)
def __init__(self, dtype: torch.dtype = torch.float32) -> None:
super().__init__()
......@@ -33,11 +31,11 @@ class ConvertImageDtype(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
output = convert_image_dtype(inpt, dtype=self.dtype)
return output if is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype)
return output if features.is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype)
class ConvertColorSpace(Transform):
_transformed_types = (is_simple_tensor, features.Image, PIL.Image.Image)
_transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image)
def __init__(
self,
......
......@@ -7,7 +7,7 @@ import torch
from torchvision.ops import remove_small_boxes
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.prototype.transforms._utils import has_any, is_simple_tensor, query_bounding_box
from torchvision.prototype.transforms._utils import has_any, query_bounding_box
from torchvision.transforms.transforms import _setup_size
......@@ -38,7 +38,7 @@ class Lambda(Transform):
class LinearTransformation(Transform):
_transformed_types = (is_simple_tensor, features.Image)
_transformed_types = (features.is_simple_tensor, features.Image)
def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
super().__init__()
......@@ -93,7 +93,7 @@ class LinearTransformation(Transform):
class Normalize(Transform):
_transformed_types = (features.Image, is_simple_tensor)
_transformed_types = (features.Image, features.is_simple_tensor)
def __init__(self, mean: Sequence[float], std: Sequence[float]):
super().__init__()
......
......@@ -5,15 +5,19 @@ import PIL.Image
import torch
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype.features import _Feature
from torchvision.prototype.transforms._utils import _isinstance, is_simple_tensor
from torchvision.prototype import features
from torchvision.prototype.transforms._utils import _isinstance
from torchvision.utils import _log_api_usage_once
class Transform(nn.Module):
# Class attribute defining transformed types. Other types are passed-through without any transformation
_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (is_simple_tensor, _Feature, PIL.Image.Image)
_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (
features.is_simple_tensor,
features._Feature,
PIL.Image.Image,
)
def __init__(self) -> None:
super().__init__()
......
......@@ -7,8 +7,6 @@ from torch.nn.functional import one_hot
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from ._utils import is_simple_tensor
class DecodeImage(Transform):
_transformed_types = (features.EncodedImage,)
......@@ -39,14 +37,14 @@ class LabelToOneHot(Transform):
class ToImageTensor(Transform):
_transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray)
_transformed_types = (features.is_simple_tensor, PIL.Image.Image, np.ndarray)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> features.Image:
return F.to_image_tensor(inpt)
class ToImagePIL(Transform):
_transformed_types = (is_simple_tensor, features.Image, np.ndarray)
_transformed_types = (features.is_simple_tensor, features.Image, np.ndarray)
def __init__(self, *, mode: Optional[str] = None) -> None:
super().__init__()
......
......@@ -23,7 +23,7 @@ def get_chw(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tupl
if isinstance(image, features.Image):
channels = image.num_channels
height, width = image.image_size
elif is_simple_tensor(image):
elif features.is_simple_tensor(image):
channels, height, width = get_dimensions_image_tensor(image)
elif isinstance(image, PIL.Image.Image):
channels, height, width = get_dimensions_image_pil(image)
......@@ -37,7 +37,7 @@ def query_chw(sample: Any) -> Tuple[int, int, int]:
chws = {
get_chw(item)
for item in flat_sample
if isinstance(item, (features.Image, PIL.Image.Image)) or is_simple_tensor(item)
if isinstance(item, (features.Image, PIL.Image.Image)) or features.is_simple_tensor(item)
}
if not chws:
raise TypeError("No image was found in the sample")
......@@ -70,10 +70,3 @@ def has_all(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -
else:
return False
return True
# TODO: This helper is unrelated to pytree / the Transform object, so it should
# probably live elsewhere. `functional._utils` would give both the functionals
# and the transforms access to it; `features` is another candidate, since the
# check is literally about the _Feature type.
def is_simple_tensor(inpt: Any) -> bool:
    """Return ``True`` if *inpt* is a plain tensor, i.e. a :class:`torch.Tensor`
    that is not one of the ``features._Feature`` subclasses."""
    is_tensor = isinstance(inpt, torch.Tensor)
    is_feature = isinstance(inpt, features._Feature)
    return is_tensor and not is_feature
......@@ -6,8 +6,6 @@ import PIL.Image
from torchvision.prototype import features
from torchvision.transforms import functional as _F
from .._utils import is_simple_tensor
def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
call = ", num_output_channels=3" if num_output_channels == 3 else ""
......@@ -23,7 +21,7 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
def rgb_to_grayscale(inpt: Any, num_output_channels: int = 1) -> Any:
old_color_space = features.Image.guess_color_space(inpt) if is_simple_tensor(inpt) else None
old_color_space = features.Image.guess_color_space(inpt) if features.is_simple_tensor(inpt) else None
call = ", num_output_channels=3" if num_output_channels == 3 else ""
replacement = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment