# test_transforms_v2_utils.py: tests for has_any / has_all from torchvision.transforms.v2.utils.
import PIL.Image
import pytest

import torch

6
import torchvision.transforms.v2.utils
7
from common_utils import make_bounding_box, make_detection_mask, make_image
8

9
10
11
from torchvision import datapoints
from torchvision.transforms.v2.functional import to_image_pil
from torchvision.transforms.v2.utils import has_all, has_any
12
13


# Shared samples used by the parametrized cases below: one object of each
# datapoint kind. BOUNDING_BOX and MASK are built with IMAGE's spatial size so
# all three describe the same spatial extent.
IMAGE = make_image(color_space="RGB")
BOUNDING_BOX = make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, spatial_size=IMAGE.spatial_size)
MASK = make_detection_mask(size=IMAGE.spatial_size)


@pytest.mark.parametrize(
    ("sample", "types", "expected"),
    [
        # Any single type present in the full sample is found.
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True),
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox,), True),
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True),
        # Several requested types, all present.
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox), True),
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True),
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox, datapoints.Mask), True),
        # None of the requested types is present in the sample.
        ((MASK,), (datapoints.Image, datapoints.BoundingBox), False),
        ((BOUNDING_BOX,), (datapoints.Image, datapoints.Mask), False),
        ((IMAGE,), (datapoints.BoundingBox, datapoints.Mask), False),
        (
            (IMAGE, BOUNDING_BOX, MASK),
            (datapoints.Image, datapoints.BoundingBox, datapoints.Mask),
            True,
        ),
        # An empty sample never matches.
        ((), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False),
        # Callables may be passed instead of types and act as predicates.
        ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True),
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
        # Each image representation (datapoint, plain tensor, PIL) is matched by
        # the combined datapoint / PIL / simple-tensor checks.
        ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor), True),
        (
            (torch.Tensor(IMAGE),),
            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor),
            True,
        ),
        (
            (to_image_pil(IMAGE),),
            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor),
            True,
        ),
    ],
)
def test_has_any(sample, types, expected):
    """Check that ``has_any`` reports whether at least one entry of ``types``
    (a type or a predicate callable) matches some object in ``sample``."""
    assert has_any(sample, *types) is expected


@pytest.mark.parametrize(
    ("sample", "types", "expected"),
    [
        # Any single type present in the full sample is satisfied.
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True),
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox,), True),
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True),
        # Several requested types, all present.
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox), True),
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True),
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox, datapoints.Mask), True),
        (
            (IMAGE, BOUNDING_BOX, MASK),
            (datapoints.Image, datapoints.BoundingBox, datapoints.Mask),
            True,
        ),
        # One requested type is missing, so the result is False even though the
        # others are present.
        ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox), False),
        ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), False),
        ((IMAGE, MASK), (datapoints.BoundingBox, datapoints.Mask), False),
        ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False),
        ((IMAGE, MASK), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False),
        ((IMAGE, BOUNDING_BOX), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False),
        # Callables may be passed instead of types and act as predicates.
        (
            (IMAGE, BOUNDING_BOX, MASK),
            (lambda obj: isinstance(obj, (datapoints.Image, datapoints.BoundingBox, datapoints.Mask)),),
            True,
        ),
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
    ],
)
def test_has_all(sample, types, expected):
    """Check that ``has_all`` is True only when every entry of ``types``
    (a type or a predicate callable) matches at least one object in ``sample``."""
    assert has_all(sample, *types) is expected