Unverified Commit 0eb8aabd authored by vfdev, committed by GitHub

[proto] Fixes Transform._transformed_types and torch.Tensor (#6487)

* Fixes unexpected behaviour with Transform._transformed_types and torch.Tensor

* Make code consistent with the has_any, has_all implementation

* Fixed failing flake8 check
parent f9966d22
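The root cause: every prototype feature (`features.Image`, `BoundingBox`, `SegmentationMask`, ...) subclasses `torch.Tensor`, so listing `torch.Tensor` in `_transformed_types` also matched feature types a given transform cannot handle. The fix replaces the bare type with the `is_simple_tensor` predicate from `transforms._utils`. Its implementation is not part of this diff; a plausible sketch:

```python
import torch
from torchvision.prototype.features import _Feature

def is_simple_tensor(inpt) -> bool:
    # "Simple" means a bare torch.Tensor that is not one of the prototype
    # feature subclasses (Image, BoundingBox, SegmentationMask, Label, ...).
    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, _Feature)
```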
@@ -225,9 +225,18 @@ class TestSmoke:
)
]
)
def test_convertolor_space(self, transform, input):
def test_convert_color_space(self, transform, input):
transform(input)
def test_convert_color_space_unsupported_types(self):
transform = transforms.ConvertColorSpace(
color_space=features.ColorSpace.RGB, old_color_space=features.ColorSpace.GRAY
)
for inpt in [make_bounding_box(format="XYXY"), make_segmentation_mask()]:
output = transform(inpt)
assert output is inpt
@pytest.mark.parametrize("p", [0.0, 1.0])
class TestRandomHorizontalFlip:
......
@@ -3,7 +3,6 @@ from typing import Any, Dict, Optional
import numpy as np
import PIL.Image
import torch
import torchvision.prototype.transforms.functional as F
from torchvision.prototype import features
from torchvision.prototype.features import ColorSpace
@@ -18,7 +17,7 @@ from ._utils import is_simple_tensor
class ToTensor(Transform):
# Updated transformed types for ToTensor
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
_transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
def __init__(self) -> None:
warnings.warn(
@@ -52,7 +51,7 @@ class PILToTensor(Transform):
class ToPILImage(Transform):
# Updated transformed types for ToPILImage
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
_transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
def __init__(self, mode: Optional[str] = None) -> None:
warnings.warn(
......
@@ -42,7 +42,7 @@ class ConvertImageDtype(Transform):
class ConvertColorSpace(Transform):
# F.convert_color_space does NOT handle `_Feature`'s in general
_transformed_types = (torch.Tensor, features.Image, PIL.Image.Image)
_transformed_types = (is_simple_tensor, features.Image, PIL.Image.Image)
def __init__(
self,
......
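For `ConvertColorSpace` this means bounding boxes and segmentation masks, although both are `torch.Tensor` subclasses, no longer match `_transformed_types` and come back untouched, while plain image tensors, `features.Image`, and PIL images are still converted. A usage sketch mirroring the new test above (the `BoundingBox` constructor arguments follow the prototype API of this commit and are illustrative):

```python
import torch
from torchvision.prototype import features, transforms

transform = transforms.ConvertColorSpace(
    color_space=features.ColorSpace.RGB, old_color_space=features.ColorSpace.GRAY
)

# A plain (non-feature) tensor matches is_simple_tensor and gets converted.
gray = torch.rand(1, 16, 16)
rgb = transform(gray)  # 1-channel grayscale -> 3-channel RGB

# A BoundingBox is a torch.Tensor subclass, but it no longer matches
# _transformed_types, so it is returned as-is.
bbox = features.BoundingBox([[0.0, 0.0, 5.0, 5.0]], format="XYXY", image_size=(16, 16))
assert transform(bbox) is bbox
```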
import enum
from typing import Any, Dict, Tuple, Type
from typing import Any, Callable, Dict, Tuple, Type, Union
import PIL.Image
import torch
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype.features import _Feature
from torchvision.prototype.transforms._utils import _isinstance, is_simple_tensor
from torchvision.utils import _log_api_usage_once
class Transform(nn.Module):
# Class attribute defining transformed types. Other types are passed-through without any transformation
_transformed_types: Tuple[Type, ...] = (torch.Tensor, _Feature, PIL.Image.Image)
_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (is_simple_tensor, _Feature, PIL.Image.Image)
def __init__(self) -> None:
super().__init__()
@@ -31,7 +32,8 @@ class Transform(nn.Module):
flat_inputs, spec = tree_flatten(sample)
flat_outputs = [
self._transform(inpt, params) if isinstance(inpt, self._transformed_types) else inpt for inpt in flat_inputs
self._transform(inpt, params) if _isinstance(inpt, self._transformed_types) else inpt
for inpt in flat_inputs
]
return tree_unflatten(flat_outputs, spec)
......
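With `_isinstance`, an entry in `_transformed_types` can be either a type or a callable check, and `forward` only routes a flattened input through `_transform` when one of the entries matches. A hypothetical subclass to illustrate (the class name and its `_transform` body are illustrative only):

```python
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform
from torchvision.prototype.transforms._utils import is_simple_tensor

class PlainTensorOnly(Transform):
    # Only plain tensors are transformed; _Feature subclasses such as
    # features.Image are rejected by the predicate and passed through.
    _transformed_types = (is_simple_tensor,)

    def _transform(self, inpt, params):
        return inpt.float() / 255

t = PlainTensorOnly()
plain = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)
image = features.Image(plain)

out = t({"plain": plain, "image": image})
assert out["plain"].dtype == torch.float32  # transformed
assert out["image"] is image                # passed through unchanged
```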
@@ -3,7 +3,6 @@ from typing import Any, Dict, Optional
import numpy as np
import PIL.Image
import torch
from torch.nn.functional import one_hot
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
@@ -44,7 +43,7 @@ class LabelToOneHot(Transform):
class ToImageTensor(Transform):
# Updated transformed types for ToImageTensor
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
_transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
def __init__(self, *, copy: bool = False) -> None:
super().__init__()
@@ -61,7 +60,7 @@ class ToImageTensor(Transform):
class ToImagePIL(Transform):
# Updated transformed types for ToImagePIL
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
_transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
def __init__(self, *, mode: Optional[str] = None) -> None:
super().__init__()
......
@@ -45,12 +45,18 @@ def query_chw(sample: Any) -> Tuple[int, int, int]:
return chws.pop()
def _isinstance(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
for type_or_check in types_or_checks:
if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
return True
return False
def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
flat_sample, _ = tree_flatten(sample)
for type_or_check in types_or_checks:
for obj in flat_sample:
if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
return True
for obj in flat_sample:
if _isinstance(obj, types_or_checks):
return True
return False
......
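The same helper makes `has_any` (and, per the commit message, `has_all`) accept a mix of types and callable checks. A small usage sketch (the sample contents are illustrative):

```python
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms._utils import has_any, is_simple_tensor

sample = {
    "image": features.Image(torch.rand(3, 8, 8)),
    "label": features.Label(1),
}

assert has_any(sample, features.Image)                     # type check
assert not has_any(sample, is_simple_tensor)               # callable check: no plain tensors here
assert has_any(sample, features.Label, is_simple_tensor)   # any single match suffices
```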