Unverified Commit 2e70ee1a authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Fix for handling numpy arrays by Transform (#6385)

* [proto] Fix for handling Numpy arrays by Transform

* transformed_types -> _transformed_types
parent 1b44be35
import itertools
import numpy as np
import PIL.Image
import pytest
......@@ -991,3 +993,94 @@ class TestRandomErasing:
fn.assert_called_once_with(erase_image_tensor_inpt, **params)
else:
fn.call_count == 0
class TestTransform:
@pytest.mark.parametrize(
"inpt_type",
[torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int],
)
def test_check_transformed_types(self, inpt_type, mocker):
# This test ensures that we correctly handle which types to transform and which to bypass
t = transforms.Transform()
inpt = mocker.MagicMock(spec=inpt_type)
if inpt_type in (np.ndarray, str, int):
output = t(inpt)
assert output is inpt
else:
with pytest.raises(NotImplementedError):
t(inpt)
class TestToImageTensor:
@pytest.mark.parametrize(
"inpt_type",
[torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch(
"torchvision.prototype.transforms.functional.to_image_tensor",
return_value=torch.rand(1, 3, 8, 8),
)
inpt = mocker.MagicMock(spec=inpt_type)
transform = transforms.ToImageTensor()
transform(inpt)
if inpt_type in (features.BoundingBox, str, int):
fn.call_count == 0
else:
fn.assert_called_once_with(inpt, copy=transform.copy)
class TestToImagePIL:
@pytest.mark.parametrize(
"inpt_type",
[torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch("torchvision.prototype.transforms.functional.to_image_pil")
inpt = mocker.MagicMock(spec=inpt_type)
transform = transforms.ToImagePIL()
transform(inpt)
if inpt_type in (features.BoundingBox, str, int):
fn.call_count == 0
else:
fn.assert_called_once_with(inpt, copy=transform.copy)
class TestToPILImage:
@pytest.mark.parametrize(
"inpt_type",
[torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch("torchvision.transforms.functional.to_pil_image")
inpt = mocker.MagicMock(spec=inpt_type)
with pytest.warns(UserWarning, match="deprecated and will be removed"):
transform = transforms.ToPILImage()
transform(inpt)
if inpt_type in (PIL.Image.Image, features.BoundingBox, str, int):
fn.call_count == 0
else:
fn.assert_called_once_with(inpt, mode=transform.mode)
class TestToTensor:
@pytest.mark.parametrize(
"inpt_type",
[torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch("torchvision.transforms.functional.to_tensor")
inpt = mocker.MagicMock(spec=inpt_type)
with pytest.warns(UserWarning, match="deprecated and will be removed"):
transform = transforms.ToTensor()
transform(inpt)
if inpt_type in (features.Image, torch.Tensor, features.BoundingBox, str, int):
fn.call_count == 0
else:
fn.assert_called_once_with(inpt)
......@@ -34,6 +34,6 @@ from ._geometry import (
)
from ._meta import ConvertBoundingBoxFormat, ConvertImageColorSpace, ConvertImageDtype
from ._misc import GaussianBlur, Identity, Lambda, Normalize, ToDtype
from ._type_conversion import DecodeImage, LabelToOneHot
from ._type_conversion import DecodeImage, LabelToOneHot, ToImagePIL, ToImageTensor
from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip
......@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
import numpy as np
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.features import ColorSpace
from torchvision.prototype.transforms import Transform
......@@ -15,6 +16,10 @@ from ._utils import is_simple_tensor
class ToTensor(Transform):
# Updated transformed types for ToTensor
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
def __init__(self) -> None:
warnings.warn(
"The transform `ToTensor()` is deprecated and will be removed in a future release. "
......@@ -23,8 +28,6 @@ class ToTensor(Transform):
super().__init__()
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# TODO: Transforms allows to pass only (torch.Tensor, _Feature, PIL.Image.Image)
# so input as np.ndarray is not possible. We need to make it possible
if isinstance(inpt, (PIL.Image.Image, np.ndarray)):
return _F.to_tensor(inpt)
else:
......@@ -47,6 +50,10 @@ class PILToTensor(Transform):
class ToPILImage(Transform):
# Updated transformed types for ToPILImage
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
def __init__(self, mode: Optional[str] = None) -> None:
warnings.warn(
"The transform `ToPILImage()` is deprecated and will be removed in a future release. "
......@@ -56,8 +63,6 @@ class ToPILImage(Transform):
self.mode = mode
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# TODO: Transforms allows to pass only (torch.Tensor, _Feature, PIL.Image.Image)
# so input as np.ndarray is not possible. We need to make it possible
if is_simple_tensor(inpt) or isinstance(inpt, (features.Image, np.ndarray)):
return _F.to_pil_image(inpt, mode=self.mode)
else:
......
import enum
from typing import Any, Dict
from typing import Any, Dict, Tuple, Type
import PIL.Image
import torch
......@@ -10,6 +10,10 @@ from torchvision.utils import _log_api_usage_once
class Transform(nn.Module):
# Class attribute defining transformed types. Other types are passed-through without any transformation
_transformed_types: Tuple[Type, ...] = (torch.Tensor, _Feature, PIL.Image.Image)
def __init__(self) -> None:
super().__init__()
_log_api_usage_once(self)
......@@ -26,9 +30,8 @@ class Transform(nn.Module):
params = self._get_params(sample)
flat_inputs, spec = tree_flatten(sample)
transformed_types = (torch.Tensor, _Feature, PIL.Image.Image)
flat_outputs = [
self._transform(inpt, params) if isinstance(inpt, transformed_types) else inpt for inpt in flat_inputs
self._transform(inpt, params) if isinstance(inpt, self._transformed_types) else inpt for inpt in flat_inputs
]
return tree_unflatten(flat_outputs, spec)
......
......@@ -2,6 +2,8 @@ from typing import Any, Dict
import numpy as np
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
......@@ -40,13 +42,15 @@ class LabelToOneHot(Transform):
class ToImageTensor(Transform):
# Updated transformed types for ToImageTensor
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
def __init__(self, *, copy: bool = False) -> None:
super().__init__()
self.copy = copy
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# TODO: Transforms allows to pass only (torch.Tensor, _Feature, PIL.Image.Image)
# so input as np.ndarray is not possible. We need to make it possible
if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
output = F.to_image_tensor(inpt, copy=self.copy)
return features.Image(output)
......@@ -55,13 +59,15 @@ class ToImageTensor(Transform):
class ToImagePIL(Transform):
# Updated transformed types for ToImagePIL
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
def __init__(self, *, copy: bool = False) -> None:
super().__init__()
self.copy = copy
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# TODO: Transforms allows to pass only (torch.Tensor, _Feature, PIL.Image.Image)
# so input as np.ndarray is not possible. We need to make it possible
if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
return F.to_image_pil(inpt, copy=self.copy)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment