test_transforms_v2_utils.py 4.01 KB
Newer Older
1
2
3
4
5
import PIL.Image
import pytest

import torch

Nicolas Hug's avatar
Nicolas Hug committed
6
import torchvision.transforms.v2._utils
7
from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_mask, make_image
8

9
from torchvision import datapoints
Nicolas Hug's avatar
Nicolas Hug committed
10
from torchvision.transforms.v2._utils import has_all, has_any
11
from torchvision.transforms.v2.functional import to_pil_image
12
13


Philip Meier's avatar
Philip Meier committed
14
# Shared sample objects reused by every parametrized case below:
#   IMAGE        -- an RGB image built by `make_image`
#   BOUNDING_BOX -- boxes in XYXY format built by `make_bounding_boxes`
#   MASK         -- a detection mask built by `make_detection_mask`
# All three are created once, at import time, with DEFAULT_SIZE.
IMAGE, BOUNDING_BOX, MASK = (
    make_image(DEFAULT_SIZE, color_space="RGB"),
    make_bounding_boxes(DEFAULT_SIZE, format=datapoints.BoundingBoxFormat.XYXY),
    make_detection_mask(DEFAULT_SIZE),
)


@pytest.mark.parametrize(
    "sample, types, expected",
    [
        # A single requested type that is present in the sample.
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True),
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes,), True),
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True),
        # Two requested types, both present.
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), True),
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True),
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes, datapoints.Mask), True),
        # None of the requested types appears in the sample.
        ((MASK,), (datapoints.Image, datapoints.BoundingBoxes), False),
        ((BOUNDING_BOX,), (datapoints.Image, datapoints.Mask), False),
        ((IMAGE,), (datapoints.BoundingBoxes, datapoints.Mask), False),
        # All three requested types, all present.
        (
            (IMAGE, BOUNDING_BOX, MASK),
            (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask),
            True,
        ),
        # An empty sample matches nothing.
        ((), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
        # Callables are accepted in place of types and act as predicates.
        ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True),
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
        # Types and the `is_pure_tensor` predicate mixed in one call: each
        # image flavour (datapoint Image, plain tensor, PIL image) should match.
        ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), True),
        (
            # NOTE(review): intended to be a plain (non-datapoint) tensor -- confirm
            # that `torch.Tensor(...)` drops the datapoint subclass here.
            (torch.Tensor(IMAGE),),
            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
            True,
        ),
        (
            (to_pil_image(IMAGE),),
            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
            True,
        ),
    ],
)
def test_has_any(sample, types, expected):
    """Check that `has_any(sample, *types)` returns exactly the bool `expected`."""
    result = has_any(sample, *types)
    # `is` (not `==`) so that a truthy non-bool return value would fail the test.
    assert result is expected


@pytest.mark.parametrize(
    ("sample", "types", "expected"),
    [
        # Every requested type is present in the sample.
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True),
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes,), True),
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True),
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), True),
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True),
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes, datapoints.Mask), True),
        (
            (IMAGE, BOUNDING_BOX, MASK),
            (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask),
            True,
        ),
        # At least one requested type is missing from the sample.
        ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), False),
        ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), False),
        ((IMAGE, MASK), (datapoints.BoundingBoxes, datapoints.Mask), False),
        ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
        ((IMAGE, MASK), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
        ((IMAGE, BOUNDING_BOX), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
        # An empty sample cannot contain any requested type (mirrors the
        # empty-sample case in test_has_any).
        ((), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
        # Callables are accepted in place of types and act as predicates.
        (
            (IMAGE, BOUNDING_BOX, MASK),
            (lambda obj: isinstance(obj, (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)),),
            True,
        ),
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
    ],
)
def test_has_all(sample, types, expected):
    """Check that `has_all(sample, *types)` returns exactly the bool `expected`.

    The all-types-present case appears once. It was previously listed twice,
    which ran the same assertion twice without adding coverage.
    """
    # `is` (not `==`) so that a truthy non-bool return value would fail the test.
    assert has_all(sample, *types) is expected