"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "fd5c3c09afe29cbbde42cfa791ca598a07d94da9"
Unverified Commit 2e70ee1a authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Fix for handling numpy arrays by Transform (#6385)

* [proto] Fix for handling Numpy arrays by Transform

* transformed_types -> _transformed_types
parent 1b44be35
import itertools import itertools
import numpy as np
import PIL.Image import PIL.Image
import pytest import pytest
...@@ -991,3 +993,94 @@ class TestRandomErasing: ...@@ -991,3 +993,94 @@ class TestRandomErasing:
fn.assert_called_once_with(erase_image_tensor_inpt, **params) fn.assert_called_once_with(erase_image_tensor_inpt, **params)
else: else:
fn.call_count == 0 fn.call_count == 0
class TestTransform:
@pytest.mark.parametrize(
"inpt_type",
[torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int],
)
def test_check_transformed_types(self, inpt_type, mocker):
# This test ensures that we correctly handle which types to transform and which to bypass
t = transforms.Transform()
inpt = mocker.MagicMock(spec=inpt_type)
if inpt_type in (np.ndarray, str, int):
output = t(inpt)
assert output is inpt
else:
with pytest.raises(NotImplementedError):
t(inpt)
class TestToImageTensor:
@pytest.mark.parametrize(
"inpt_type",
[torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch(
"torchvision.prototype.transforms.functional.to_image_tensor",
return_value=torch.rand(1, 3, 8, 8),
)
inpt = mocker.MagicMock(spec=inpt_type)
transform = transforms.ToImageTensor()
transform(inpt)
if inpt_type in (features.BoundingBox, str, int):
fn.call_count == 0
else:
fn.assert_called_once_with(inpt, copy=transform.copy)
class TestToImagePIL:
@pytest.mark.parametrize(
"inpt_type",
[torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch("torchvision.prototype.transforms.functional.to_image_pil")
inpt = mocker.MagicMock(spec=inpt_type)
transform = transforms.ToImagePIL()
transform(inpt)
if inpt_type in (features.BoundingBox, str, int):
fn.call_count == 0
else:
fn.assert_called_once_with(inpt, copy=transform.copy)
class TestToPILImage:
@pytest.mark.parametrize(
"inpt_type",
[torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch("torchvision.transforms.functional.to_pil_image")
inpt = mocker.MagicMock(spec=inpt_type)
with pytest.warns(UserWarning, match="deprecated and will be removed"):
transform = transforms.ToPILImage()
transform(inpt)
if inpt_type in (PIL.Image.Image, features.BoundingBox, str, int):
fn.call_count == 0
else:
fn.assert_called_once_with(inpt, mode=transform.mode)
class TestToTensor:
@pytest.mark.parametrize(
"inpt_type",
[torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch("torchvision.transforms.functional.to_tensor")
inpt = mocker.MagicMock(spec=inpt_type)
with pytest.warns(UserWarning, match="deprecated and will be removed"):
transform = transforms.ToTensor()
transform(inpt)
if inpt_type in (features.Image, torch.Tensor, features.BoundingBox, str, int):
fn.call_count == 0
else:
fn.assert_called_once_with(inpt)
...@@ -34,6 +34,6 @@ from ._geometry import ( ...@@ -34,6 +34,6 @@ from ._geometry import (
) )
from ._meta import ConvertBoundingBoxFormat, ConvertImageColorSpace, ConvertImageDtype from ._meta import ConvertBoundingBoxFormat, ConvertImageColorSpace, ConvertImageDtype
from ._misc import GaussianBlur, Identity, Lambda, Normalize, ToDtype from ._misc import GaussianBlur, Identity, Lambda, Normalize, ToDtype
from ._type_conversion import DecodeImage, LabelToOneHot from ._type_conversion import DecodeImage, LabelToOneHot, ToImagePIL, ToImageTensor
from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip
...@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional ...@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import torch
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.prototype.features import ColorSpace from torchvision.prototype.features import ColorSpace
from torchvision.prototype.transforms import Transform from torchvision.prototype.transforms import Transform
...@@ -15,6 +16,10 @@ from ._utils import is_simple_tensor ...@@ -15,6 +16,10 @@ from ._utils import is_simple_tensor
class ToTensor(Transform): class ToTensor(Transform):
# Updated transformed types for ToTensor
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
def __init__(self) -> None: def __init__(self) -> None:
warnings.warn( warnings.warn(
"The transform `ToTensor()` is deprecated and will be removed in a future release. " "The transform `ToTensor()` is deprecated and will be removed in a future release. "
...@@ -23,8 +28,6 @@ class ToTensor(Transform): ...@@ -23,8 +28,6 @@ class ToTensor(Transform):
super().__init__() super().__init__()
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# TODO: Transforms allows to pass only (torch.Tensor, _Feature, PIL.Image.Image)
# so input as np.ndarray is not possible. We need to make it possible
if isinstance(inpt, (PIL.Image.Image, np.ndarray)): if isinstance(inpt, (PIL.Image.Image, np.ndarray)):
return _F.to_tensor(inpt) return _F.to_tensor(inpt)
else: else:
...@@ -47,6 +50,10 @@ class PILToTensor(Transform): ...@@ -47,6 +50,10 @@ class PILToTensor(Transform):
class ToPILImage(Transform): class ToPILImage(Transform):
# Updated transformed types for ToPILImage
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
def __init__(self, mode: Optional[str] = None) -> None: def __init__(self, mode: Optional[str] = None) -> None:
warnings.warn( warnings.warn(
"The transform `ToPILImage()` is deprecated and will be removed in a future release. " "The transform `ToPILImage()` is deprecated and will be removed in a future release. "
...@@ -56,8 +63,6 @@ class ToPILImage(Transform): ...@@ -56,8 +63,6 @@ class ToPILImage(Transform):
self.mode = mode self.mode = mode
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# TODO: Transforms allows to pass only (torch.Tensor, _Feature, PIL.Image.Image)
# so input as np.ndarray is not possible. We need to make it possible
if is_simple_tensor(inpt) or isinstance(inpt, (features.Image, np.ndarray)): if is_simple_tensor(inpt) or isinstance(inpt, (features.Image, np.ndarray)):
return _F.to_pil_image(inpt, mode=self.mode) return _F.to_pil_image(inpt, mode=self.mode)
else: else:
......
import enum import enum
from typing import Any, Dict from typing import Any, Dict, Tuple, Type
import PIL.Image import PIL.Image
import torch import torch
...@@ -10,6 +10,10 @@ from torchvision.utils import _log_api_usage_once ...@@ -10,6 +10,10 @@ from torchvision.utils import _log_api_usage_once
class Transform(nn.Module): class Transform(nn.Module):
# Class attribute defining transformed types. Other types are passed-through without any transformation
_transformed_types: Tuple[Type, ...] = (torch.Tensor, _Feature, PIL.Image.Image)
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
_log_api_usage_once(self) _log_api_usage_once(self)
...@@ -26,9 +30,8 @@ class Transform(nn.Module): ...@@ -26,9 +30,8 @@ class Transform(nn.Module):
params = self._get_params(sample) params = self._get_params(sample)
flat_inputs, spec = tree_flatten(sample) flat_inputs, spec = tree_flatten(sample)
transformed_types = (torch.Tensor, _Feature, PIL.Image.Image)
flat_outputs = [ flat_outputs = [
self._transform(inpt, params) if isinstance(inpt, transformed_types) else inpt for inpt in flat_inputs self._transform(inpt, params) if isinstance(inpt, self._transformed_types) else inpt for inpt in flat_inputs
] ]
return tree_unflatten(flat_outputs, spec) return tree_unflatten(flat_outputs, spec)
......
...@@ -2,6 +2,8 @@ from typing import Any, Dict ...@@ -2,6 +2,8 @@ from typing import Any, Dict
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import torch
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform from torchvision.prototype.transforms import functional as F, Transform
...@@ -40,13 +42,15 @@ class LabelToOneHot(Transform): ...@@ -40,13 +42,15 @@ class LabelToOneHot(Transform):
class ToImageTensor(Transform): class ToImageTensor(Transform):
# Updated transformed types for ToImageTensor
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
def __init__(self, *, copy: bool = False) -> None: def __init__(self, *, copy: bool = False) -> None:
super().__init__() super().__init__()
self.copy = copy self.copy = copy
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# TODO: Transforms allows to pass only (torch.Tensor, _Feature, PIL.Image.Image)
# so input as np.ndarray is not possible. We need to make it possible
if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt): if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
output = F.to_image_tensor(inpt, copy=self.copy) output = F.to_image_tensor(inpt, copy=self.copy)
return features.Image(output) return features.Image(output)
...@@ -55,13 +59,15 @@ class ToImageTensor(Transform): ...@@ -55,13 +59,15 @@ class ToImageTensor(Transform):
class ToImagePIL(Transform): class ToImagePIL(Transform):
# Updated transformed types for ToImagePIL
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
def __init__(self, *, copy: bool = False) -> None: def __init__(self, *, copy: bool = False) -> None:
super().__init__() super().__init__()
self.copy = copy self.copy = copy
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# TODO: Transforms allows to pass only (torch.Tensor, _Feature, PIL.Image.Image)
# so input as np.ndarray is not possible. We need to make it possible
if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt): if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
return F.to_image_pil(inpt, copy=self.copy) return F.to_image_pil(inpt, copy=self.copy)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment