Unverified Commit 330b6c9b authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

expand has_any and has_all to also accept check callables (#6447)

* expand has_any and has_all to also accept check callables

* add test and fix has_all

* add support for simple tensor images to CutMix, MixUp and RandomIoUCrop

* remove TODO

* remove pythonic syntax sugar

* simplify

* use concreate examples in test rather than abstract ones

* simplify further
parent 80c197ad
import PIL.Image
import pytest
import torch
from test_prototype_transforms_functional import make_bounding_box, make_image, make_segmentation_mask
from torchvision.prototype import features
from torchvision.prototype.transforms._utils import has_all, has_any, is_simple_tensor
from torchvision.prototype.transforms.functional import to_image_pil
IMAGE = make_image(color_space=features.ColorSpace.RGB)
BOUNDING_BOX = make_bounding_box(format=features.BoundingBoxFormat.XYXY, image_size=IMAGE.image_size)
SEGMENTATION_MASK = make_segmentation_mask(size=IMAGE.image_size)
@pytest.mark.parametrize(
("sample", "types", "expected"),
[
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.SegmentationMask,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), True),
((SEGMENTATION_MASK,), (features.Image, features.BoundingBox), False),
((BOUNDING_BOX,), (features.Image, features.SegmentationMask), False),
((IMAGE,), (features.BoundingBox, features.SegmentationMask), False),
(
(IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
(features.Image, features.BoundingBox, features.SegmentationMask),
True,
),
((), (features.Image, features.BoundingBox, features.SegmentationMask), False),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda obj: isinstance(obj, features.Image),), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True),
((IMAGE,), (features.Image, PIL.Image.Image, is_simple_tensor), True),
((torch.Tensor(IMAGE),), (features.Image, PIL.Image.Image, is_simple_tensor), True),
((to_image_pil(IMAGE),), (features.Image, PIL.Image.Image, is_simple_tensor), True),
],
)
def test_has_any(sample, types, expected):
assert has_any(sample, *types) is expected
@pytest.mark.parametrize(
("sample", "types", "expected"),
[
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.SegmentationMask,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), True),
(
(IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
(features.Image, features.BoundingBox, features.SegmentationMask),
True,
),
((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), False),
((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), False),
((IMAGE, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), False),
(
(IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
(features.Image, features.BoundingBox, features.SegmentationMask),
True,
),
((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox, features.SegmentationMask), False),
((IMAGE, SEGMENTATION_MASK), (features.Image, features.BoundingBox, features.SegmentationMask), False),
((IMAGE, BOUNDING_BOX), (features.Image, features.BoundingBox, features.SegmentationMask), False),
(
(IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
(lambda obj: isinstance(obj, (features.Image, features.BoundingBox, features.SegmentationMask)),),
True,
),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True),
],
)
def test_has_all(sample, types, expected):
assert has_all(sample, *types) is expected
...@@ -9,7 +9,7 @@ from torchvision.prototype import features ...@@ -9,7 +9,7 @@ from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F from torchvision.prototype.transforms import functional as F
from ._transform import _RandomApplyTransform from ._transform import _RandomApplyTransform
from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_image from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_image
class RandomErasing(_RandomApplyTransform): class RandomErasing(_RandomApplyTransform):
...@@ -105,7 +105,9 @@ class _BaseMixupCutmix(_RandomApplyTransform): ...@@ -105,7 +105,9 @@ class _BaseMixupCutmix(_RandomApplyTransform):
def forward(self, *inpts: Any) -> Any: def forward(self, *inpts: Any) -> Any:
sample = inpts if len(inpts) > 1 else inpts[0] sample = inpts if len(inpts) > 1 else inpts[0]
if not has_all(sample, features.Image, features.OneHotLabel): if not (
has_any(sample, features.Image, PIL.Image.Image, is_simple_tensor) and has_any(sample, features.OneHotLabel)
):
raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.") raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
if has_any(sample, features.BoundingBox, features.SegmentationMask, features.Label): if has_any(sample, features.BoundingBox, features.SegmentationMask, features.Label):
raise TypeError( raise TypeError(
......
...@@ -719,10 +719,9 @@ class RandomIoUCrop(Transform): ...@@ -719,10 +719,9 @@ class RandomIoUCrop(Transform):
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0] sample = inputs if len(inputs) > 1 else inputs[0]
# TODO: Allow image to be a torch.Tensor
if not ( if not (
has_all(sample, features.BoundingBox) has_all(sample, features.BoundingBox)
and has_any(sample, PIL.Image.Image, features.Image) and has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor)
and has_any(sample, features.Label, features.OneHotLabel) and has_any(sample, features.Label, features.OneHotLabel)
): ):
raise TypeError( raise TypeError(
......
from typing import Any, Tuple, Type, Union from typing import Any, Callable, Tuple, Type, Union
import PIL.Image import PIL.Image
import torch import torch
...@@ -39,14 +39,24 @@ def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Im ...@@ -39,14 +39,24 @@ def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Im
return channels, height, width return channels, height, width
def has_any(sample: Any, *types: Type) -> bool: def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
flat_sample, _ = tree_flatten(sample) flat_sample, _ = tree_flatten(sample)
return any(issubclass(type(obj), types) for obj in flat_sample) for type_or_check in types_or_checks:
for obj in flat_sample:
if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
return True
return False
def has_all(sample: Any, *types: Type) -> bool: def has_all(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
flat_sample, _ = tree_flatten(sample) flat_sample, _ = tree_flatten(sample)
return not bool(set(types) - set([type(obj) for obj in flat_sample])) for type_or_check in types_or_checks:
for obj in flat_sample:
if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
break
else:
return False
return True
def is_simple_tensor(inpt: Any) -> bool: def is_simple_tensor(inpt: Any) -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment