import importlib.machinery
import importlib.util
import random
from pathlib import Path

import pytest

import torch
import torchvision.transforms.v2 as v2_transforms
from common_utils import assert_equal
from torchvision import tv_tensors
from torchvision.transforms import functional as legacy_F
from torchvision.transforms.v2 import functional as prototype_F
from torchvision.transforms.v2._utils import _get_fill, query_size
from torchvision.transforms.v2.functional import to_pil_image

from transforms_v2_legacy_utils import make_bounding_boxes, make_detection_mask, make_image, make_segmentation_mask


def import_transforms_from_references(reference):
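    # Load the `transforms` module of the given references subfolder by executing its source file directly,
    # since the references are not an installable package.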
    HERE = Path(__file__).parent
    PROJECT_ROOT = HERE.parent

    loader = importlib.machinery.SourceFileLoader(
        "transforms", str(PROJECT_ROOT / "references" / reference / "transforms.py")
    )
    spec = importlib.util.spec_from_loader("transforms", loader)
    module = importlib.util.module_from_spec(spec)
    loader.exec_module(module)
    return module


det_transforms = import_transforms_from_references("detection")


class TestRefDetTransforms:
    def make_tv_tensors(self, with_mask=True):
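        # Yield (image, target) detection samples with the image as a PIL image, a plain tensor,
        # and a tv_tensors.Image.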
        size = (600, 800)
        num_objects = 22

        def make_label(extra_dims, categories):
            return torch.randint(categories, extra_dims, dtype=torch.int64)

        pil_image = to_pil_image(make_image(size=size, color_space="RGB"))
        target = {
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (pil_image, target)

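        # Wrapping in torch.Tensor drops the tv_tensors.Image subclass, so this exercises the plain-tensor path.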
        tensor_image = torch.Tensor(make_image(size=size, color_space="RGB", dtype=torch.float32))
        target = {
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (tensor_image, target)

        tv_tensor_image = make_image(size=size, color_space="RGB", dtype=torch.float32)
        target = {
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (tv_tensor_image, target)

    @pytest.mark.parametrize(
        "t_ref, t, data_kwargs",
        [
            (det_transforms.RandomHorizontalFlip(p=1.0), v2_transforms.RandomHorizontalFlip(p=1.0), {}),
            (
                det_transforms.RandomIoUCrop(),
                v2_transforms.Compose(
                    [
                        v2_transforms.RandomIoUCrop(),
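                        # The reference RandomIoUCrop drops boxes (and labels) that fall outside the sampled
                        # crop itself; chaining SanitizeBoundingBoxes replicates that cleanup for the v2 pipeline.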
                        v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]),
                    ]
                ),
                {"with_mask": False},
            ),
            (det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}),
            (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024), antialias=True), {}),
            (
                det_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                v2_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                {},
            ),
        ],
    )
    def test_transform(self, t_ref, t, data_kwargs):
        for dp in self.make_tv_tensors(**data_kwargs):
            # Apply the v2 transform first, since the reference transform updates the target in place.
            torch.manual_seed(12)
            output = t(dp)

            torch.manual_seed(12)
            expected_output = t_ref(*dp)

            assert_equal(expected_output, output)


seg_transforms = import_transforms_from_references("segmentation")


# We need this transform for two reasons:
# 1. transforms.RandomCrop pads images and masks of insufficient size with a different scheme than its
#    counterpart in the segmentation references. Thus, we cannot use it with `pad_if_needed=True`.
# 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size.
class PadIfSmaller(v2_transforms.Transform):
    def __init__(self, size, fill=0):
        super().__init__()
        self.size = size
        self.fill = v2_transforms._geometry._setup_fill_arg(fill)

    def _get_params(self, sample):
        height, width = query_size(sample)
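        # Padding is given as [left, top, right, bottom]; only the right and bottom edges are ever padded.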
        padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
        needs_padding = any(padding)
        return dict(padding=padding, needs_padding=needs_padding)

    def _transform(self, inpt, params):
        if not params["needs_padding"]:
            return inpt

        fill = _get_fill(self.fill, type(inpt))
        return prototype_F.pad(inpt, padding=params["padding"], fill=fill)


class TestRefSegTransforms:
    def make_tv_tensors(self, supports_pil=True, image_dtype=torch.uint8):
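        # Yield (v2 sample, reference sample) pairs; the reference transforms operate on PIL masks.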
        size = (256, 460)
        num_categories = 21

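        # Run each sample as a PIL image (when supported), a plain tensor, and a tv_tensors.Image.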
        conv_fns = []
        if supports_pil:
            conv_fns.append(to_pil_image)
        conv_fns.extend([torch.Tensor, lambda x: x])

        for conv_fn in conv_fns:
            tv_tensor_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
            tv_tensor_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)

            dp = (conv_fn(tv_tensor_image), tv_tensor_mask)
            dp_ref = (
                to_pil_image(tv_tensor_image) if supports_pil else tv_tensor_image.as_subclass(torch.Tensor),
                to_pil_image(tv_tensor_mask),
            )

            yield dp, dp_ref

    def set_seed(self, seed=12):
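        # The reference transforms draw from Python's `random` module, while v2 draws from torch, so seed both.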
        torch.manual_seed(seed)
        random.seed(seed)

    def check(self, t, t_ref, data_kwargs=None):
        for dp, dp_ref in self.make_tv_tensors(**(data_kwargs or dict())):

            self.set_seed()
            actual = actual_image, actual_mask = t(dp)

            self.set_seed()
            expected_image, expected_mask = t_ref(*dp_ref)
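            # The reference pipeline may return PIL outputs; convert them to tensors so assert_equal can compare.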
            if isinstance(actual_image, torch.Tensor) and not isinstance(expected_image, torch.Tensor):
                expected_image = legacy_F.pil_to_tensor(expected_image)
            expected_mask = legacy_F.pil_to_tensor(expected_mask).squeeze(0)
            expected = (expected_image, expected_mask)

            assert_equal(actual, expected)

    @pytest.mark.parametrize(
        ("t_ref", "t", "data_kwargs"),
        [
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
                v2_transforms.RandomHorizontalFlip(p=1.0),
                dict(),
            ),
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
                v2_transforms.RandomHorizontalFlip(p=0.0),
                dict(),
            ),
            (
                seg_transforms.RandomCrop(size=480),
                v2_transforms.Compose(
                    [
                        PadIfSmaller(size=480, fill={tv_tensors.Mask: 255, "others": 0}),
                        v2_transforms.RandomCrop(size=480),
                    ]
                ),
                dict(),
            ),
            (
                seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                v2_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                dict(supports_pil=False, image_dtype=torch.float),
            ),
        ],
    )
    def test_common(self, t_ref, t, data_kwargs):
        self.check(t, t_ref, data_kwargs)