test_transforms_v2_utils.py 4.01 KB
Newer Older
1
2
3
4
5
import PIL.Image
import pytest

import torch

Nicolas Hug's avatar
Nicolas Hug committed
6
import torchvision.transforms.v2._utils
7
from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_mask, make_image
8

9
from torchvision import tv_tensors
Nicolas Hug's avatar
Nicolas Hug committed
10
from torchvision.transforms.v2._utils import has_all, has_any
11
from torchvision.transforms.v2.functional import to_pil_image
12
13


Philip Meier's avatar
Philip Meier committed
14
# Sample image built once at import time with DEFAULT_SIZE and an RGB color space.
# The parametrized cases in this module use it as the element expected to match tv_tensors.Image.
IMAGE = make_image(DEFAULT_SIZE, color_space="RGB")
15
# Sample bounding boxes built once at import time with DEFAULT_SIZE in XYXY format.
# The parametrized cases in this module use it as the element expected to match tv_tensors.BoundingBoxes.
BOUNDING_BOX = make_bounding_boxes(DEFAULT_SIZE, format=tv_tensors.BoundingBoxFormat.XYXY)
Philip Meier's avatar
Philip Meier committed
16
# Sample detection mask built once at import time with DEFAULT_SIZE.
# The parametrized cases in this module use it as the element expected to match tv_tensors.Mask.
MASK = make_detection_mask(DEFAULT_SIZE)
17
18
19
20
21


# Argument tuples shared by several ``has_any`` cases.
_ANY_FULL_SAMPLE = (IMAGE, BOUNDING_BOX, MASK)
_ANY_EVERY_TV_TENSOR = (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask)
_ANY_IMAGE_LIKE = (
    tv_tensors.Image,
    PIL.Image.Image,
    torchvision.transforms.v2._utils.is_pure_tensor,
)


@pytest.mark.parametrize(
    ("sample", "types", "expected"),
    [
        # One requested type, present in the sample.
        (_ANY_FULL_SAMPLE, (tv_tensors.Image,), True),
        (_ANY_FULL_SAMPLE, (tv_tensors.BoundingBoxes,), True),
        (_ANY_FULL_SAMPLE, (tv_tensors.Mask,), True),
        # Two requested types, both present in the sample.
        (_ANY_FULL_SAMPLE, (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
        (_ANY_FULL_SAMPLE, (tv_tensors.Image, tv_tensors.Mask), True),
        (_ANY_FULL_SAMPLE, (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
        # Single-element samples whose only element matches none of the requested types.
        ((MASK,), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
        ((BOUNDING_BOX,), (tv_tensors.Image, tv_tensors.Mask), False),
        ((IMAGE,), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
        # All three types requested: full sample vs. empty sample.
        (_ANY_FULL_SAMPLE, _ANY_EVERY_TV_TENSOR, True),
        ((), _ANY_EVERY_TV_TENSOR, False),
        # Callables are passed in place of types.
        (_ANY_FULL_SAMPLE, (lambda item: isinstance(item, tv_tensors.Image),), True),
        (_ANY_FULL_SAMPLE, (lambda _: False,), False),
        (_ANY_FULL_SAMPLE, (lambda _: True,), True),
        # Mixed types and a callable against a tv_tensor image, a torch.Tensor built from it, and a PIL image.
        ((IMAGE,), _ANY_IMAGE_LIKE, True),
        ((torch.Tensor(IMAGE),), _ANY_IMAGE_LIKE, True),
        ((to_pil_image(IMAGE),), _ANY_IMAGE_LIKE, True),
    ],
)
def test_has_any(sample, types, expected):
    """Check that ``has_any(sample, *types)`` returns exactly the ``expected`` boolean.

    ``types`` is unpacked into positional arguments; its entries are classes,
    callables taking one object, or a mix of both.
    """
    outcome = has_any(sample, *types)
    # Identity comparison: the result must be the bool object itself, not merely truthy/falsy.
    assert outcome is expected


@pytest.mark.parametrize(
    ("sample", "types", "expected"),
    [
        # One requested type, present in the full sample.
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
        # Two requested types, both present in the full sample.
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
        # All three requested types present. This case was previously listed twice;
        # it is kept once so the identical test is not collected and run twice.
        (
            (IMAGE, BOUNDING_BOX, MASK),
            (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
            True,
        ),
        # Two requested types, one of them missing from the sample.
        ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
        ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), False),
        ((IMAGE, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
        # Three requested types, exactly one of them missing from the sample.
        ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
        ((IMAGE, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
        ((IMAGE, BOUNDING_BOX), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
        # Callables are passed in place of types.
        (
            (IMAGE, BOUNDING_BOX, MASK),
            (lambda obj: isinstance(obj, (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask)),),
            True,
        ),
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
    ],
)
def test_has_all(sample, types, expected):
    """Check that ``has_all(sample, *types)`` returns exactly the ``expected`` boolean.

    ``types`` is unpacked into positional arguments; its entries are classes or
    callables taking one object. The comparison uses ``is`` so the result must be
    the bool object itself, not merely a truthy or falsy value.
    """
    assert has_all(sample, *types) is expected