from copy import deepcopy

import pytest
import torch
from common_utils import assert_equal, make_bounding_box, make_image, make_segmentation_mask, make_video

from PIL import Image

from torchvision import datapoints

@pytest.fixture(autouse=True)
def restore_tensor_return_type():
    """Reset the global return type to "Tensor" after every test.

    Each test is expected to restore the default itself; this autouse fixture
    is a safety net for the ones that forget.
    """
    yield
    datapoints.set_return_type("Tensor")


@pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)])
def test_image_instance(data):
    """Image wraps tensor and PIL inputs into a 3-channel tensor subclass."""
    img = datapoints.Image(data)
    assert isinstance(img, torch.Tensor)
    assert img.ndim == 3
    assert img.shape[0] == 3


@pytest.mark.parametrize("data", [torch.randint(0, 10, size=(1, 32, 32)), Image.new("L", (32, 32), color=2)])
def test_mask_instance(data):
    """Mask wraps tensor and PIL inputs into a single-channel tensor subclass."""
    mask = datapoints.Mask(data)
    assert isinstance(mask, torch.Tensor)
    assert mask.ndim == 3
    assert mask.shape[0] == 1


@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]], [1, 2, 3, 4]])
@pytest.mark.parametrize(
    "format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH]
)
def test_bbox_instance(data, format):
    """BoundingBoxes accepts tensors and (nested) lists and records its format."""
    bboxes = datapoints.BoundingBoxes(data, format=format, canvas_size=(32, 32))

    assert isinstance(bboxes, torch.Tensor)
    assert bboxes.ndim == 2
    assert bboxes.shape[1] == 4

    # String formats are normalized to the corresponding enum member on construction.
    expected = datapoints.BoundingBoxFormat[format.upper()] if isinstance(format, str) else format
    assert bboxes.format == expected


def test_bbox_dim_error():
    """Wrapping a 3D payload as bounding boxes must be rejected."""
    with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"):
        datapoints.BoundingBoxes([[[1, 2, 3, 4]]], format="XYXY", canvas_size=(32, 32))


@pytest.mark.parametrize(
    ("data", "input_requires_grad", "expected_requires_grad"),
    [
        ([[[0.0, 1.0], [0.0, 1.0]]], None, False),
        ([[[0.0, 1.0], [0.0, 1.0]]], False, False),
        ([[[0.0, 1.0], [0.0, 1.0]]], True, True),
        (torch.rand(3, 16, 16, requires_grad=False), None, False),
        (torch.rand(3, 16, 16, requires_grad=False), False, False),
        (torch.rand(3, 16, 16, requires_grad=False), True, True),
        (torch.rand(3, 16, 16, requires_grad=True), None, True),
        (torch.rand(3, 16, 16, requires_grad=True), False, False),
        (torch.rand(3, 16, 16, requires_grad=True), True, True),
    ],
)
def test_new_requires_grad(data, input_requires_grad, expected_requires_grad):
    """An explicit requires_grad overrides whatever the input data carries; None inherits it."""
    image = datapoints.Image(data, requires_grad=input_requires_grad)
    assert image.requires_grad is expected_requires_grad


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
def test_isinstance(make_input):
    """Every datapoint kind is a torch.Tensor subclass."""
    dp = make_input()
    assert isinstance(dp, torch.Tensor)


def test_wrapping_no_copy():
    """Wrapping a tensor as an Image shares storage instead of copying it."""
    tensor = torch.rand(3, 16, 16)
    wrapped = datapoints.Image(tensor)

    assert wrapped.data_ptr() == tensor.data_ptr()


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
def test_to_wrapping(make_input):
    """`.to()` keeps the datapoint subclass while converting the dtype."""
    dp = make_input()

    converted = dp.to(torch.float64)

    assert type(converted) is type(dp)
    assert converted.dtype is torch.float64


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_to_datapoint_reference(make_input, return_type):
    """`tensor.to(datapoint)` adopts the dtype; it only yields a datapoint under the "datapoint" return type."""
    tensor = torch.rand((3, 16, 16), dtype=torch.float64)
    dp = make_input()

    with datapoints.set_return_type(return_type):
        converted = tensor.to(dp)

    expected_type = type(dp) if return_type == "datapoint" else torch.Tensor
    assert type(converted) is expected_type
    assert converted.dtype is dp.dtype


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_clone_wrapping(make_input, return_type):
    """`clone()` always preserves the subclass and copies the storage, regardless of return type."""
    dp = make_input()

    with datapoints.set_return_type(return_type):
        cloned = dp.clone()

    assert type(cloned) is type(dp)
    assert cloned.data_ptr() != dp.data_ptr()


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_requires_grad__wrapping(make_input, return_type):
    """`requires_grad_()` keeps the subclass and flips the flag in place."""
    dp = make_input(dtype=torch.float)

    assert not dp.requires_grad

    with datapoints.set_return_type(return_type):
        result = dp.requires_grad_(True)

    assert type(result) is type(dp)
    assert dp.requires_grad
    assert result.requires_grad


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_detach_wrapping(make_input, return_type):
    """`detach()` preserves the datapoint subclass under both return types."""
    dp = make_input(dtype=torch.float).requires_grad_(True)

    with datapoints.set_return_type(return_type):
        detached = dp.detach()

    assert type(detached) is type(dp)


@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_force_subclass_with_metadata(return_type):
    # Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS and datapoints with metadata.
    # Largely the same as above, we additionally check that the metadata is preserved.
    format, canvas_size = "XYXY", (32, 32)
    bbox = datapoints.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size)

    def assert_metadata_preserved():
        # BUG FIX: the original wrote `assert bbox.format, bbox.canvas_size == (format, canvas_size)`,
        # which parses as `assert <expr>, <message>` — only the truthiness of bbox.format was ever
        # checked, and the canvas_size comparison was a dead assert message.
        # Note: the constructor normalizes the string format to its enum member.
        assert bbox.format == datapoints.BoundingBoxFormat[format]
        assert bbox.canvas_size == canvas_size

    datapoints.set_return_type(return_type)
    bbox = bbox.clone()
    if return_type == "datapoint":
        assert_metadata_preserved()

    bbox = bbox.to(torch.float64)
    if return_type == "datapoint":
        assert_metadata_preserved()

    bbox = bbox.detach()
    if return_type == "datapoint":
        assert_metadata_preserved()

    assert not bbox.requires_grad
    bbox.requires_grad_(True)
    if return_type == "datapoint":
        assert_metadata_preserved()
        assert bbox.requires_grad
    datapoints.set_return_type("tensor")


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_other_op_no_wrapping(make_input, return_type):
    """Generic ops unwrap to plain tensors unless the "datapoint" return type is active."""
    dp = make_input()

    with datapoints.set_return_type(return_type):
        # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
        doubled = dp * 2

    expected_type = type(dp) if return_type == "datapoint" else torch.Tensor
    assert type(doubled) is expected_type


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize(
    "op",
    [
        lambda t: t.numpy(),
        lambda t: t.tolist(),
        lambda t: t.max(dim=-1),
    ],
)
def test_no_tensor_output_op_no_wrapping(make_input, op):
    """Ops whose result is not a plain tensor never come back as the datapoint subclass."""
    dp = make_input()

    result = op(dp)

    assert type(result) is not type(dp)


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_inplace_op_no_wrapping(make_input, return_type):
    """In-place ops follow the active return type but never change the input's own type."""
    dp = make_input()
    original_type = type(dp)

    with datapoints.set_return_type(return_type):
        result = dp.add_(0)

    expected_type = original_type if return_type == "datapoint" else torch.Tensor
    assert type(result) is expected_type
    assert type(dp) is original_type


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
def test_wrap_like(make_input):
    """`wrap_like` re-wraps a plain tensor as the given datapoint type without copying."""
    dp = make_input()

    # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
    unwrapped = dp * 2

    rewrapped = type(dp).wrap_like(dp, unwrapped)

    assert type(rewrapped) is type(dp)
    assert rewrapped.data_ptr() == unwrapped.data_ptr()


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("requires_grad", [False, True])
def test_deepcopy(make_input, requires_grad):
    """deepcopy yields an equal but independent datapoint with the same grad flag."""
    dp = make_input(dtype=torch.float)
    dp.requires_grad_(requires_grad)

    copied = deepcopy(dp)

    assert copied is not dp
    assert copied.data_ptr() != dp.data_ptr()
    assert_equal(copied, dp)

    assert type(copied) is type(dp)
    assert copied.requires_grad is requires_grad


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
@pytest.mark.parametrize(
    "op",
    (
        lambda dp: dp + torch.rand(*dp.shape),
        lambda dp: torch.rand(*dp.shape) + dp,
        lambda dp: dp * torch.rand(*dp.shape),
        lambda dp: torch.rand(*dp.shape) * dp,
        lambda dp: dp + 3,
        lambda dp: 3 + dp,
        lambda dp: dp + dp,
        lambda dp: dp.sum(),
        lambda dp: dp.reshape(-1),
        lambda dp: dp.int(),
        lambda dp: torch.stack([dp, dp]),
        lambda dp: torch.chunk(dp, 2)[0],
        lambda dp: torch.unbind(dp)[0],
    ),
)
def test_usual_operations(make_input, return_type, op):
    """Common tensor ops respect the active return type; box metadata survives wrapping."""
    dp = make_input()

    with datapoints.set_return_type(return_type):
        result = op(dp)

    expected_type = type(dp) if return_type == "datapoint" else torch.Tensor
    assert type(result) is expected_type

    if isinstance(dp, datapoints.BoundingBoxes) and return_type == "datapoint":
        assert hasattr(result, "format")
        assert hasattr(result, "canvas_size")


def test_subclasses():
    """Combining two different datapoint subclasses in one op raises TypeError."""
    img = make_image()
    masks = make_segmentation_mask()

    with pytest.raises(TypeError, match="unsupported operand"):
        img + masks


def test_set_return_type():
    """set_return_type works both as a context manager and as a global switch."""
    image = make_image()

    # Default: ops unwrap to plain tensors.
    assert type(image + 3) is torch.Tensor

    # As a context manager, the setting only applies inside the block.
    with datapoints.set_return_type("datapoint"):
        assert type(image + 3) is datapoints.Image
    assert type(image + 3) is torch.Tensor

    # As a plain call, the setting sticks globally.
    datapoints.set_return_type("datapoint")
    assert type(image + 3) is datapoints.Image

    with datapoints.set_return_type("tensor"):
        assert type(image + 3) is torch.Tensor
        with datapoints.set_return_type("datapoint"):
            assert type(image + 3) is datapoints.Image
            datapoints.set_return_type("tensor")
            assert type(image + 3) is torch.Tensor
        assert type(image + 3) is torch.Tensor
    # Exiting a context manager restores the return type as it was prior to entering it,
    # regardless of whether the "global" datapoints.set_return_type() was called inside.
    assert type(image + 3) is datapoints.Image

    datapoints.set_return_type("tensor")