Unverified Commit 0eb8aabd authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Fixes Transform._transformed_types and torch.Tensor (#6487)

* Fixes unexpected behaviour with Transform._transformed_types and torch.Tensor

* Make code consistent to has_any, has_all implementation

* Fixed failing flake8 check
parent f9966d22
...@@ -225,9 +225,18 @@ class TestSmoke: ...@@ -225,9 +225,18 @@ class TestSmoke:
) )
] ]
) )
def test_convertolor_space(self, transform, input): def test_convert_color_space(self, transform, input):
transform(input) transform(input)
def test_convert_color_space_unsupported_types(self):
transform = transforms.ConvertColorSpace(
color_space=features.ColorSpace.RGB, old_color_space=features.ColorSpace.GRAY
)
for inpt in [make_bounding_box(format="XYXY"), make_segmentation_mask()]:
output = transform(inpt)
assert output is inpt
@pytest.mark.parametrize("p", [0.0, 1.0]) @pytest.mark.parametrize("p", [0.0, 1.0])
class TestRandomHorizontalFlip: class TestRandomHorizontalFlip:
......
...@@ -3,7 +3,6 @@ from typing import Any, Dict, Optional ...@@ -3,7 +3,6 @@ from typing import Any, Dict, Optional
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import torch
import torchvision.prototype.transforms.functional as F import torchvision.prototype.transforms.functional as F
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.prototype.features import ColorSpace from torchvision.prototype.features import ColorSpace
...@@ -18,7 +17,7 @@ from ._utils import is_simple_tensor ...@@ -18,7 +17,7 @@ from ._utils import is_simple_tensor
class ToTensor(Transform): class ToTensor(Transform):
# Updated transformed types for ToTensor # Updated transformed types for ToTensor
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray) _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
def __init__(self) -> None: def __init__(self) -> None:
warnings.warn( warnings.warn(
...@@ -52,7 +51,7 @@ class PILToTensor(Transform): ...@@ -52,7 +51,7 @@ class PILToTensor(Transform):
class ToPILImage(Transform): class ToPILImage(Transform):
# Updated transformed types for ToPILImage # Updated transformed types for ToPILImage
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray) _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
def __init__(self, mode: Optional[str] = None) -> None: def __init__(self, mode: Optional[str] = None) -> None:
warnings.warn( warnings.warn(
......
...@@ -42,7 +42,7 @@ class ConvertImageDtype(Transform): ...@@ -42,7 +42,7 @@ class ConvertImageDtype(Transform):
class ConvertColorSpace(Transform): class ConvertColorSpace(Transform):
# F.convert_color_space does NOT handle `_Feature`'s in general # F.convert_color_space does NOT handle `_Feature`'s in general
_transformed_types = (torch.Tensor, features.Image, PIL.Image.Image) _transformed_types = (is_simple_tensor, features.Image, PIL.Image.Image)
def __init__( def __init__(
self, self,
......
import enum import enum
from typing import Any, Dict, Tuple, Type from typing import Any, Callable, Dict, Tuple, Type, Union
import PIL.Image import PIL.Image
import torch import torch
from torch import nn from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype.features import _Feature from torchvision.prototype.features import _Feature
from torchvision.prototype.transforms._utils import _isinstance, is_simple_tensor
from torchvision.utils import _log_api_usage_once from torchvision.utils import _log_api_usage_once
class Transform(nn.Module): class Transform(nn.Module):
# Class attribute defining transformed types. Other types are passed-through without any transformation # Class attribute defining transformed types. Other types are passed-through without any transformation
_transformed_types: Tuple[Type, ...] = (torch.Tensor, _Feature, PIL.Image.Image) _transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (is_simple_tensor, _Feature, PIL.Image.Image)
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
...@@ -31,7 +32,8 @@ class Transform(nn.Module): ...@@ -31,7 +32,8 @@ class Transform(nn.Module):
flat_inputs, spec = tree_flatten(sample) flat_inputs, spec = tree_flatten(sample)
flat_outputs = [ flat_outputs = [
self._transform(inpt, params) if isinstance(inpt, self._transformed_types) else inpt for inpt in flat_inputs self._transform(inpt, params) if _isinstance(inpt, self._transformed_types) else inpt
for inpt in flat_inputs
] ]
return tree_unflatten(flat_outputs, spec) return tree_unflatten(flat_outputs, spec)
......
...@@ -3,7 +3,6 @@ from typing import Any, Dict, Optional ...@@ -3,7 +3,6 @@ from typing import Any, Dict, Optional
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import torch
from torch.nn.functional import one_hot from torch.nn.functional import one_hot
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform from torchvision.prototype.transforms import functional as F, Transform
...@@ -44,7 +43,7 @@ class LabelToOneHot(Transform): ...@@ -44,7 +43,7 @@ class LabelToOneHot(Transform):
class ToImageTensor(Transform): class ToImageTensor(Transform):
# Updated transformed types for ToImageTensor # Updated transformed types for ToImageTensor
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray) _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
def __init__(self, *, copy: bool = False) -> None: def __init__(self, *, copy: bool = False) -> None:
super().__init__() super().__init__()
...@@ -61,7 +60,7 @@ class ToImageTensor(Transform): ...@@ -61,7 +60,7 @@ class ToImageTensor(Transform):
class ToImagePIL(Transform): class ToImagePIL(Transform):
# Updated transformed types for ToImagePIL # Updated transformed types for ToImagePIL
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray) _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
def __init__(self, *, mode: Optional[str] = None) -> None: def __init__(self, *, mode: Optional[str] = None) -> None:
super().__init__() super().__init__()
......
...@@ -45,12 +45,18 @@ def query_chw(sample: Any) -> Tuple[int, int, int]: ...@@ -45,12 +45,18 @@ def query_chw(sample: Any) -> Tuple[int, int, int]:
return chws.pop() return chws.pop()
def _isinstance(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
for type_or_check in types_or_checks:
if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
return True
return False
def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool: def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
flat_sample, _ = tree_flatten(sample) flat_sample, _ = tree_flatten(sample)
for type_or_check in types_or_checks: for obj in flat_sample:
for obj in flat_sample: if _isinstance(obj, types_or_checks):
if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj): return True
return True
return False return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment