Unverified Commit 7de63171 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

move simple_tensor to features module (#6507)

* move simple_tensor to features module

* fix test
parent 13ea9018
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
from test_prototype_transforms_functional import make_bounding_box, make_image, make_segmentation_mask from test_prototype_transforms_functional import make_bounding_box, make_image, make_segmentation_mask
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.prototype.transforms._utils import has_all, has_any, is_simple_tensor from torchvision.prototype.transforms._utils import has_all, has_any
from torchvision.prototype.transforms.functional import to_image_pil from torchvision.prototype.transforms.functional import to_image_pil
...@@ -36,9 +36,9 @@ SEGMENTATION_MASK = make_segmentation_mask(size=IMAGE.image_size) ...@@ -36,9 +36,9 @@ SEGMENTATION_MASK = make_segmentation_mask(size=IMAGE.image_size)
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda obj: isinstance(obj, features.Image),), True), ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda obj: isinstance(obj, features.Image),), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False), ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True), ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True),
((IMAGE,), (features.Image, PIL.Image.Image, is_simple_tensor), True), ((IMAGE,), (features.Image, PIL.Image.Image, features.is_simple_tensor), True),
((torch.Tensor(IMAGE),), (features.Image, PIL.Image.Image, is_simple_tensor), True), ((torch.Tensor(IMAGE),), (features.Image, PIL.Image.Image, features.is_simple_tensor), True),
((to_image_pil(IMAGE),), (features.Image, PIL.Image.Image, is_simple_tensor), True), ((to_image_pil(IMAGE),), (features.Image, PIL.Image.Image, features.is_simple_tensor), True),
], ],
) )
def test_has_any(sample, types, expected): def test_has_any(sample, types, expected):
......
from ._bounding_box import BoundingBox, BoundingBoxFormat from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._encoded import EncodedData, EncodedImage, EncodedVideo from ._encoded import EncodedData, EncodedImage, EncodedVideo
from ._feature import _Feature from ._feature import _Feature, is_simple_tensor
from ._image import ColorSpace, Image from ._image import ColorSpace, Image
from ._label import Label, OneHotLabel from ._label import Label, OneHotLabel
from ._segmentation_mask import SegmentationMask from ._segmentation_mask import SegmentationMask
...@@ -10,6 +10,10 @@ from torchvision.transforms import InterpolationMode ...@@ -10,6 +10,10 @@ from torchvision.transforms import InterpolationMode
F = TypeVar("F", bound="_Feature") F = TypeVar("F", bound="_Feature")
def is_simple_tensor(inpt: Any) -> bool:
return isinstance(inpt, torch.Tensor) and not isinstance(inpt, _Feature)
class _Feature(torch.Tensor): class _Feature(torch.Tensor):
__F: Optional[ModuleType] = None __F: Optional[ModuleType] = None
......
...@@ -13,7 +13,7 @@ from torchvision.prototype.transforms import functional as F ...@@ -13,7 +13,7 @@ from torchvision.prototype.transforms import functional as F
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor from torchvision.transforms.functional import InterpolationMode, pil_to_tensor
from ._transform import _RandomApplyTransform from ._transform import _RandomApplyTransform
from ._utils import has_any, is_simple_tensor, query_chw from ._utils import has_any, query_chw
class RandomErasing(_RandomApplyTransform): class RandomErasing(_RandomApplyTransform):
...@@ -102,7 +102,7 @@ class _BaseMixupCutmix(_RandomApplyTransform): ...@@ -102,7 +102,7 @@ class _BaseMixupCutmix(_RandomApplyTransform):
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
if not (has_any(inputs, features.Image, is_simple_tensor) and has_any(inputs, features.OneHotLabel)): if not (has_any(inputs, features.Image, features.is_simple_tensor) and has_any(inputs, features.OneHotLabel)):
raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.") raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.")
if has_any(inputs, features.BoundingBox, features.SegmentationMask, features.Label): if has_any(inputs, features.BoundingBox, features.SegmentationMask, features.Label):
raise TypeError( raise TypeError(
...@@ -124,7 +124,7 @@ class RandomMixup(_BaseMixupCutmix): ...@@ -124,7 +124,7 @@ class RandomMixup(_BaseMixupCutmix):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
lam = params["lam"] lam = params["lam"]
if isinstance(inpt, features.Image) or is_simple_tensor(inpt): if isinstance(inpt, features.Image) or features.is_simple_tensor(inpt):
if inpt.ndim < 4: if inpt.ndim < 4:
raise ValueError("Need a batch of images") raise ValueError("Need a batch of images")
output = inpt.clone() output = inpt.clone()
...@@ -164,7 +164,7 @@ class RandomCutmix(_BaseMixupCutmix): ...@@ -164,7 +164,7 @@ class RandomCutmix(_BaseMixupCutmix):
return dict(box=box, lam_adjusted=lam_adjusted) return dict(box=box, lam_adjusted=lam_adjusted)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, features.Image) or is_simple_tensor(inpt): if isinstance(inpt, features.Image) or features.is_simple_tensor(inpt):
box = params["box"] box = params["box"]
if inpt.ndim < 4: if inpt.ndim < 4:
raise ValueError("Need a batch of images") raise ValueError("Need a batch of images")
...@@ -276,7 +276,7 @@ class SimpleCopyPaste(_RandomApplyTransform): ...@@ -276,7 +276,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
# with List[image], List[BoundingBox], List[SegmentationMask], List[Label] # with List[image], List[BoundingBox], List[SegmentationMask], List[Label]
images, bboxes, masks, labels = [], [], [], [] images, bboxes, masks, labels = [], [], [], []
for obj in flat_sample: for obj in flat_sample:
if isinstance(obj, features.Image) or is_simple_tensor(obj): if isinstance(obj, features.Image) or features.is_simple_tensor(obj):
images.append(obj) images.append(obj)
elif isinstance(obj, PIL.Image.Image): elif isinstance(obj, PIL.Image.Image):
images.append(pil_to_tensor(obj)) images.append(pil_to_tensor(obj))
...@@ -310,7 +310,7 @@ class SimpleCopyPaste(_RandomApplyTransform): ...@@ -310,7 +310,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
elif isinstance(obj, PIL.Image.Image): elif isinstance(obj, PIL.Image.Image):
flat_sample[i] = F.to_image_pil(output_images[c0]) flat_sample[i] = F.to_image_pil(output_images[c0])
c0 += 1 c0 += 1
elif is_simple_tensor(obj): elif features.is_simple_tensor(obj):
flat_sample[i] = output_images[c0] flat_sample[i] = output_images[c0]
c0 += 1 c0 += 1
elif isinstance(obj, features.BoundingBox): elif isinstance(obj, features.BoundingBox):
......
...@@ -11,7 +11,7 @@ from torchvision.prototype.transforms import functional as F, Transform ...@@ -11,7 +11,7 @@ from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.autoaugment import AutoAugmentPolicy from torchvision.transforms.autoaugment import AutoAugmentPolicy
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
from ._utils import _isinstance, get_chw, is_simple_tensor from ._utils import _isinstance, get_chw
K = TypeVar("K") K = TypeVar("K")
V = TypeVar("V") V = TypeVar("V")
...@@ -44,7 +44,7 @@ class _AutoAugmentBase(Transform): ...@@ -44,7 +44,7 @@ class _AutoAugmentBase(Transform):
sample_flat, _ = tree_flatten(sample) sample_flat, _ = tree_flatten(sample)
images = [] images = []
for id, inpt in enumerate(sample_flat): for id, inpt in enumerate(sample_flat):
if _isinstance(inpt, (features.Image, PIL.Image.Image, is_simple_tensor)): if _isinstance(inpt, (features.Image, PIL.Image.Image, features.is_simple_tensor)):
images.append((id, inpt)) images.append((id, inpt))
elif isinstance(inpt, unsupported_types): elif isinstance(inpt, unsupported_types):
raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()") raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
......
...@@ -8,7 +8,7 @@ from torchvision.prototype.transforms import functional as F, Transform ...@@ -8,7 +8,7 @@ from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms import functional as _F from torchvision.transforms import functional as _F
from ._transform import _RandomApplyTransform from ._transform import _RandomApplyTransform
from ._utils import is_simple_tensor, query_chw from ._utils import query_chw
T = TypeVar("T", features.Image, torch.Tensor, PIL.Image.Image) T = TypeVar("T", features.Image, torch.Tensor, PIL.Image.Image)
...@@ -112,7 +112,7 @@ class RandomPhotometricDistort(Transform): ...@@ -112,7 +112,7 @@ class RandomPhotometricDistort(Transform):
) )
def _permute_channels(self, inpt: Any, *, permutation: torch.Tensor) -> Any: def _permute_channels(self, inpt: Any, *, permutation: torch.Tensor) -> Any:
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)): if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or features.is_simple_tensor(inpt)):
return inpt return inpt
image = inpt image = inpt
......
...@@ -11,7 +11,7 @@ from torchvision.transforms import functional as _F ...@@ -11,7 +11,7 @@ from torchvision.transforms import functional as _F
from typing_extensions import Literal from typing_extensions import Literal
from ._transform import _RandomApplyTransform from ._transform import _RandomApplyTransform
from ._utils import is_simple_tensor, query_chw from ._utils import query_chw
class ToTensor(Transform): class ToTensor(Transform):
...@@ -43,7 +43,7 @@ class PILToTensor(Transform): ...@@ -43,7 +43,7 @@ class PILToTensor(Transform):
class ToPILImage(Transform): class ToPILImage(Transform):
_transformed_types = (is_simple_tensor, features.Image, np.ndarray) _transformed_types = (features.is_simple_tensor, features.Image, np.ndarray)
def __init__(self, mode: Optional[str] = None) -> None: def __init__(self, mode: Optional[str] = None) -> None:
warnings.warn( warnings.warn(
...@@ -58,7 +58,7 @@ class ToPILImage(Transform): ...@@ -58,7 +58,7 @@ class ToPILImage(Transform):
class Grayscale(Transform): class Grayscale(Transform):
_transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor) _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None: def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
deprecation_msg = ( deprecation_msg = (
...@@ -86,7 +86,7 @@ class Grayscale(Transform): ...@@ -86,7 +86,7 @@ class Grayscale(Transform):
class RandomGrayscale(_RandomApplyTransform): class RandomGrayscale(_RandomApplyTransform):
_transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor) _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
def __init__(self, p: float = 0.1) -> None: def __init__(self, p: float = 0.1) -> None:
warnings.warn( warnings.warn(
......
...@@ -15,7 +15,7 @@ from torchvision.transforms.transforms import _check_sequence_input, _setup_angl ...@@ -15,7 +15,7 @@ from torchvision.transforms.transforms import _check_sequence_input, _setup_angl
from typing_extensions import Literal from typing_extensions import Literal
from ._transform import _RandomApplyTransform from ._transform import _RandomApplyTransform
from ._utils import has_all, has_any, is_simple_tensor, query_bounding_box, query_chw from ._utils import has_all, has_any, query_bounding_box, query_chw
class RandomHorizontalFlip(_RandomApplyTransform): class RandomHorizontalFlip(_RandomApplyTransform):
...@@ -156,7 +156,7 @@ class FiveCrop(Transform): ...@@ -156,7 +156,7 @@ class FiveCrop(Transform):
torch.Size([5]) torch.Size([5])
""" """
_transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor) _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
def __init__(self, size: Union[int, Sequence[int]]) -> None: def __init__(self, size: Union[int, Sequence[int]]) -> None:
super().__init__() super().__init__()
...@@ -176,7 +176,7 @@ class TenCrop(Transform): ...@@ -176,7 +176,7 @@ class TenCrop(Transform):
See :class:`~torchvision.prototype.transforms.FiveCrop` for an example. See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
""" """
_transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor) _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None: def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
super().__init__() super().__init__()
...@@ -696,7 +696,7 @@ class RandomIoUCrop(Transform): ...@@ -696,7 +696,7 @@ class RandomIoUCrop(Transform):
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
if not ( if not (
has_all(inputs, features.BoundingBox) has_all(inputs, features.BoundingBox)
and has_any(inputs, PIL.Image.Image, features.Image, is_simple_tensor) and has_any(inputs, PIL.Image.Image, features.Image, features.is_simple_tensor)
and has_any(inputs, features.Label, features.OneHotLabel) and has_any(inputs, features.Label, features.OneHotLabel)
): ):
raise TypeError( raise TypeError(
...@@ -847,7 +847,7 @@ class FixedSizeCrop(Transform): ...@@ -847,7 +847,7 @@ class FixedSizeCrop(Transform):
return inpt return inpt
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
if not has_any(inputs, PIL.Image.Image, features.Image, is_simple_tensor): if not has_any(inputs, PIL.Image.Image, features.Image, features.is_simple_tensor):
raise TypeError(f"{type(self).__name__}() requires input sample to contain an tensor or PIL image.") raise TypeError(f"{type(self).__name__}() requires input sample to contain an tensor or PIL image.")
if has_any(inputs, features.BoundingBox) and not has_any(inputs, features.Label, features.OneHotLabel): if has_any(inputs, features.BoundingBox) and not has_any(inputs, features.Label, features.OneHotLabel):
......
...@@ -7,8 +7,6 @@ from torchvision.prototype import features ...@@ -7,8 +7,6 @@ from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.functional import convert_image_dtype from torchvision.transforms.functional import convert_image_dtype
from ._utils import is_simple_tensor
class ConvertBoundingBoxFormat(Transform): class ConvertBoundingBoxFormat(Transform):
_transformed_types = (features.BoundingBox,) _transformed_types = (features.BoundingBox,)
...@@ -25,7 +23,7 @@ class ConvertBoundingBoxFormat(Transform): ...@@ -25,7 +23,7 @@ class ConvertBoundingBoxFormat(Transform):
class ConvertImageDtype(Transform): class ConvertImageDtype(Transform):
_transformed_types = (is_simple_tensor, features.Image) _transformed_types = (features.is_simple_tensor, features.Image)
def __init__(self, dtype: torch.dtype = torch.float32) -> None: def __init__(self, dtype: torch.dtype = torch.float32) -> None:
super().__init__() super().__init__()
...@@ -33,11 +31,11 @@ class ConvertImageDtype(Transform): ...@@ -33,11 +31,11 @@ class ConvertImageDtype(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
output = convert_image_dtype(inpt, dtype=self.dtype) output = convert_image_dtype(inpt, dtype=self.dtype)
return output if is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype) return output if features.is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype)
class ConvertColorSpace(Transform): class ConvertColorSpace(Transform):
_transformed_types = (is_simple_tensor, features.Image, PIL.Image.Image) _transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image)
def __init__( def __init__(
self, self,
......
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
from torchvision.ops import remove_small_boxes from torchvision.ops import remove_small_boxes
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform from torchvision.prototype.transforms import functional as F, Transform
from torchvision.prototype.transforms._utils import has_any, is_simple_tensor, query_bounding_box from torchvision.prototype.transforms._utils import has_any, query_bounding_box
from torchvision.transforms.transforms import _setup_size from torchvision.transforms.transforms import _setup_size
...@@ -38,7 +38,7 @@ class Lambda(Transform): ...@@ -38,7 +38,7 @@ class Lambda(Transform):
class LinearTransformation(Transform): class LinearTransformation(Transform):
_transformed_types = (is_simple_tensor, features.Image) _transformed_types = (features.is_simple_tensor, features.Image)
def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor): def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
super().__init__() super().__init__()
...@@ -93,7 +93,7 @@ class LinearTransformation(Transform): ...@@ -93,7 +93,7 @@ class LinearTransformation(Transform):
class Normalize(Transform): class Normalize(Transform):
_transformed_types = (features.Image, is_simple_tensor) _transformed_types = (features.Image, features.is_simple_tensor)
def __init__(self, mean: Sequence[float], std: Sequence[float]): def __init__(self, mean: Sequence[float], std: Sequence[float]):
super().__init__() super().__init__()
......
...@@ -5,15 +5,19 @@ import PIL.Image ...@@ -5,15 +5,19 @@ import PIL.Image
import torch import torch
from torch import nn from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype.features import _Feature from torchvision.prototype import features
from torchvision.prototype.transforms._utils import _isinstance, is_simple_tensor from torchvision.prototype.transforms._utils import _isinstance
from torchvision.utils import _log_api_usage_once from torchvision.utils import _log_api_usage_once
class Transform(nn.Module): class Transform(nn.Module):
# Class attribute defining transformed types. Other types are passed-through without any transformation # Class attribute defining transformed types. Other types are passed-through without any transformation
_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (is_simple_tensor, _Feature, PIL.Image.Image) _transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (
features.is_simple_tensor,
features._Feature,
PIL.Image.Image,
)
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
......
...@@ -7,8 +7,6 @@ from torch.nn.functional import one_hot ...@@ -7,8 +7,6 @@ from torch.nn.functional import one_hot
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform from torchvision.prototype.transforms import functional as F, Transform
from ._utils import is_simple_tensor
class DecodeImage(Transform): class DecodeImage(Transform):
_transformed_types = (features.EncodedImage,) _transformed_types = (features.EncodedImage,)
...@@ -39,14 +37,14 @@ class LabelToOneHot(Transform): ...@@ -39,14 +37,14 @@ class LabelToOneHot(Transform):
class ToImageTensor(Transform): class ToImageTensor(Transform):
_transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray) _transformed_types = (features.is_simple_tensor, PIL.Image.Image, np.ndarray)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> features.Image: def _transform(self, inpt: Any, params: Dict[str, Any]) -> features.Image:
return F.to_image_tensor(inpt) return F.to_image_tensor(inpt)
class ToImagePIL(Transform): class ToImagePIL(Transform):
_transformed_types = (is_simple_tensor, features.Image, np.ndarray) _transformed_types = (features.is_simple_tensor, features.Image, np.ndarray)
def __init__(self, *, mode: Optional[str] = None) -> None: def __init__(self, *, mode: Optional[str] = None) -> None:
super().__init__() super().__init__()
......
...@@ -23,7 +23,7 @@ def get_chw(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tupl ...@@ -23,7 +23,7 @@ def get_chw(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tupl
if isinstance(image, features.Image): if isinstance(image, features.Image):
channels = image.num_channels channels = image.num_channels
height, width = image.image_size height, width = image.image_size
elif is_simple_tensor(image): elif features.is_simple_tensor(image):
channels, height, width = get_dimensions_image_tensor(image) channels, height, width = get_dimensions_image_tensor(image)
elif isinstance(image, PIL.Image.Image): elif isinstance(image, PIL.Image.Image):
channels, height, width = get_dimensions_image_pil(image) channels, height, width = get_dimensions_image_pil(image)
...@@ -37,7 +37,7 @@ def query_chw(sample: Any) -> Tuple[int, int, int]: ...@@ -37,7 +37,7 @@ def query_chw(sample: Any) -> Tuple[int, int, int]:
chws = { chws = {
get_chw(item) get_chw(item)
for item in flat_sample for item in flat_sample
if isinstance(item, (features.Image, PIL.Image.Image)) or is_simple_tensor(item) if isinstance(item, (features.Image, PIL.Image.Image)) or features.is_simple_tensor(item)
} }
if not chws: if not chws:
raise TypeError("No image was found in the sample") raise TypeError("No image was found in the sample")
...@@ -70,10 +70,3 @@ def has_all(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) - ...@@ -70,10 +70,3 @@ def has_all(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -
else: else:
return False return False
return True return True
# TODO: Given that this is not related to pytree / the Transform object, we should probably move it to somewhere else.
# One possibility is `functional._utils` so both the functionals and the transforms have proper access to it. We could
# also move it `features` since it literally checks for the _Feature type.
def is_simple_tensor(inpt: Any) -> bool:
return isinstance(inpt, torch.Tensor) and not isinstance(inpt, features._Feature)
...@@ -6,8 +6,6 @@ import PIL.Image ...@@ -6,8 +6,6 @@ import PIL.Image
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.transforms import functional as _F from torchvision.transforms import functional as _F
from .._utils import is_simple_tensor
def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image: def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
call = ", num_output_channels=3" if num_output_channels == 3 else "" call = ", num_output_channels=3" if num_output_channels == 3 else ""
...@@ -23,7 +21,7 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima ...@@ -23,7 +21,7 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
def rgb_to_grayscale(inpt: Any, num_output_channels: int = 1) -> Any: def rgb_to_grayscale(inpt: Any, num_output_channels: int = 1) -> Any:
old_color_space = features.Image.guess_color_space(inpt) if is_simple_tensor(inpt) else None old_color_space = features.Image.guess_color_space(inpt) if features.is_simple_tensor(inpt) else None
call = ", num_output_channels=3" if num_output_channels == 3 else "" call = ", num_output_channels=3" if num_output_channels == 3 else ""
replacement = ( replacement = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment