test_datapoints.py 10.7 KB
Newer Older
1
2
from copy import deepcopy

3
4
import pytest
import torch
5
from common_utils import assert_equal, make_bounding_box, make_image, make_segmentation_mask, make_video
6
7
8
from PIL import Image

from torchvision import datapoints
9
10
11


@pytest.fixture(autouse=True)
def restore_tensor_return_type():
    """Reset the global return type to ``"Tensor"`` after every test.

    Tests are expected to restore the default themselves (at least at the
    time of writing...); this autouse fixture is a safety net so one
    misbehaving test cannot leak state into the next.
    """
    yield
    datapoints.set_return_type("Tensor")
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32


@pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)])
def test_image_instance(data):
    # Image should wrap both raw tensors and PIL images as a CHW tensor subclass.
    wrapped = datapoints.Image(data)
    assert isinstance(wrapped, torch.Tensor)
    assert wrapped.ndim == 3
    assert wrapped.shape[0] == 3


@pytest.mark.parametrize("data", [torch.randint(0, 10, size=(1, 32, 32)), Image.new("L", (32, 32), color=2)])
def test_mask_instance(data):
    # Mask should wrap both raw tensors and single-channel PIL images.
    wrapped = datapoints.Mask(data)
    assert isinstance(wrapped, torch.Tensor)
    assert wrapped.ndim == 3
    assert wrapped.shape[0] == 1


33
@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]], [1, 2, 3, 4]])
@pytest.mark.parametrize(
    "format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH]
)
def test_bbox_instance(data, format):
    # The constructor accepts the format both as a string and as the enum,
    # and promotes 1D input to a (1, 4) batch.
    boxes = datapoints.BoundingBoxes(data, format=format, canvas_size=(32, 32))
    assert isinstance(boxes, torch.Tensor)
    assert boxes.ndim == 2
    assert boxes.shape[1] == 4
    # String formats are normalized to the enum by the constructor.
    expected = datapoints.BoundingBoxFormat[format.upper()] if isinstance(format, str) else format
    assert boxes.format == expected
44
45


46
47
48
49
50
51
def test_bbox_dim_error():
    # Anything beyond the (num_boxes, 4) layout must be rejected.
    bad_data = [[[1, 2, 3, 4]]]
    with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"):
        datapoints.BoundingBoxes(bad_data, format="XYXY", canvas_size=(32, 32))


52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
@pytest.mark.parametrize(
    ("data", "input_requires_grad", "expected_requires_grad"),
    [
        ([[[0.0, 1.0], [0.0, 1.0]]], None, False),
        ([[[0.0, 1.0], [0.0, 1.0]]], False, False),
        ([[[0.0, 1.0], [0.0, 1.0]]], True, True),
        (torch.rand(3, 16, 16, requires_grad=False), None, False),
        (torch.rand(3, 16, 16, requires_grad=False), False, False),
        (torch.rand(3, 16, 16, requires_grad=False), True, True),
        (torch.rand(3, 16, 16, requires_grad=True), None, True),
        (torch.rand(3, 16, 16, requires_grad=True), False, False),
        (torch.rand(3, 16, 16, requires_grad=True), True, True),
    ],
)
def test_new_requires_grad(data, input_requires_grad, expected_requires_grad):
    # An explicit requires_grad wins over the input's own flag; None means
    # "inherit from the input" (plain Python data defaults to False).
    image = datapoints.Image(data, requires_grad=input_requires_grad)
    assert image.requires_grad is expected_requires_grad


71
72
73
@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
def test_isinstance(make_input):
    # Every datapoint subclass is still a plain torch.Tensor.
    sample = make_input()
    assert isinstance(sample, torch.Tensor)
74
75
76
77
78
79
80
81
82


def test_wrapping_no_copy():
    # Wrapping a tensor must alias its storage, not copy it.
    source = torch.rand(3, 16, 16)
    wrapped = datapoints.Image(source)
    assert wrapped.data_ptr() == source.data_ptr()


83
84
85
@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
def test_to_wrapping(make_input):
    # .to() must preserve the datapoint subclass of the result.
    original = make_input()

    converted = original.to(torch.float64)

    assert type(converted) is type(original)
    assert converted.dtype is torch.float64
91
92


93
@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_to_datapoint_reference(make_input, return_type):
    # Calling plain_tensor.to(datapoint) only wraps the result when the
    # return type is set to "datapoint"; the source tensor is untouched.
    plain = torch.rand((3, 16, 16), dtype=torch.float64)
    reference = make_input()

    with datapoints.set_return_type(return_type):
        converted = plain.to(reference)

    expected_type = type(reference) if return_type == "datapoint" else torch.Tensor
    assert type(converted) is expected_type
    assert converted.dtype is reference.dtype
    assert type(plain) is torch.Tensor
105
106


107
@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_clone_wrapping(make_input, return_type):
    # clone() keeps the subclass under either return type, but must produce
    # fresh storage.
    original = make_input()

    with datapoints.set_return_type(return_type):
        cloned = original.clone()

    assert type(cloned) is type(original)
    assert cloned.data_ptr() != original.data_ptr()
117
118


119
@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_requires_grad__wrapping(make_input, return_type):
    # requires_grad_() mutates in place and returns the same subclass.
    sample = make_input(dtype=torch.float)
    assert not sample.requires_grad

    with datapoints.set_return_type(return_type):
        result = sample.requires_grad_(True)

    assert type(result) is type(sample)
    assert sample.requires_grad
    assert result.requires_grad
132
133


134
@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_detach_wrapping(make_input, return_type):
    # detach() keeps the datapoint subclass under either return type.
    sample = make_input(dtype=torch.float).requires_grad_(True)

    with datapoints.set_return_type(return_type):
        detached = sample.detach()

    assert type(detached) is type(sample)
143
144


145
146
147
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_force_subclass_with_metadata(return_type):
    """Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS on datapoints with metadata.

    Largely the same as the wrapping tests above, but additionally checks
    that ``format`` and ``canvas_size`` survive each op.
    """
    format, canvas_size = "XYXY", (32, 32)
    bbox = datapoints.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size)
    # The constructor normalizes the string format to the enum, so the
    # metadata comparison must be against the enum member, not the string.
    expected_format = datapoints.BoundingBoxFormat[format]

    def assert_metadata_preserved():
        # Metadata is only accessible when results come back as datapoints.
        # BUG FIX: this used to be `assert bbox.format, bbox.canvas_size == ...`,
        # which is the two-expression assert form — it only checked the
        # truthiness of `bbox.format` and never compared the metadata.
        if return_type == "datapoint":
            assert (bbox.format, bbox.canvas_size) == (expected_format, canvas_size)

    datapoints.set_return_type(return_type)
    bbox = bbox.clone()
    assert_metadata_preserved()

    bbox = bbox.to(torch.float64)
    assert_metadata_preserved()

    bbox = bbox.detach()
    assert_metadata_preserved()

    assert not bbox.requires_grad
    bbox.requires_grad_(True)
    assert_metadata_preserved()
    if return_type == "datapoint":
        assert bbox.requires_grad
    datapoints.set_return_type("tensor")
171
172


173
@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_other_op_no_wrapping(make_input, return_type):
    sample = make_input()

    with datapoints.set_return_type(return_type):
        # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
        result = sample * 2

    expected_type = type(sample) if return_type == "datapoint" else torch.Tensor
    assert type(result) is expected_type
183
184


185
@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize(
    "op",
    [
        lambda t: t.numpy(),
        lambda t: t.tolist(),
        lambda t: t.max(dim=-1),
    ],
)
def test_no_tensor_output_op_no_wrapping(make_input, op):
    # Ops whose result is not a plain tensor (ndarray, list, namedtuple)
    # must never be wrapped into the datapoint subclass.
    sample = make_input()
    result = op(sample)
    assert type(result) is not type(sample)
200
201


202
@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_inplace_op_no_wrapping(make_input, return_type):
    sample = make_input()
    original_type = type(sample)

    with datapoints.set_return_type(return_type):
        returned = sample.add_(0)

    expected_type = original_type if return_type == "datapoint" else torch.Tensor
    assert type(returned) is expected_type
    # The object mutated in place keeps its class regardless of return type.
    assert type(sample) is original_type
213
214


215
216
217
@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
def test_wrap_like(make_input):
    sample = make_input()

    # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
    unwrapped = sample * 2

    rewrapped = type(sample).wrap_like(sample, unwrapped)

    assert type(rewrapped) is type(sample)
    # wrap_like must alias the data of its input, not copy it.
    assert rewrapped.data_ptr() == unwrapped.data_ptr()
226
227


228
@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("requires_grad", [False, True])
def test_deepcopy(make_input, requires_grad):
    sample = make_input(dtype=torch.float)
    sample.requires_grad_(requires_grad)

    copied = deepcopy(sample)

    # A deep copy is a distinct object backed by distinct storage ...
    assert copied is not sample
    assert copied.data_ptr() != sample.data_ptr()
    # ... that still compares equal to the original ...
    assert_equal(copied, sample)
    # ... and keeps both the subclass and the autograd flag.
    assert type(copied) is type(sample)
    assert copied.requires_grad is requires_grad
243
244


245
@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
@pytest.mark.parametrize(
    "op",
    (
        lambda dp: dp + torch.rand(*dp.shape),
        lambda dp: torch.rand(*dp.shape) + dp,
        lambda dp: dp * torch.rand(*dp.shape),
        lambda dp: torch.rand(*dp.shape) * dp,
        lambda dp: dp + 3,
        lambda dp: 3 + dp,
        lambda dp: dp + dp,
        lambda dp: dp.sum(),
        lambda dp: dp.reshape(-1),
        lambda dp: dp.int(),
        lambda dp: torch.stack([dp, dp]),
        lambda dp: torch.chunk(dp, 2)[0],
        lambda dp: torch.unbind(dp)[0],
    ),
)
def test_usual_operations(make_input, return_type, op):
    # A grab bag of everyday tensor ops: the result class follows the
    # configured return type, and BoundingBoxes keep their metadata
    # attributes when returned as datapoints.
    sample = make_input()
    with datapoints.set_return_type(return_type):
        result = op(sample)

    expected_type = type(sample) if return_type == "datapoint" else torch.Tensor
    assert type(result) is expected_type

    if return_type == "datapoint" and isinstance(sample, datapoints.BoundingBoxes):
        assert hasattr(result, "format")
        assert hasattr(result, "canvas_size")


def test_subclasses():
    # Mixing two different datapoint subclasses in one op is ambiguous and
    # therefore rejected.
    image = make_image()
    mask = make_segmentation_mask()

    with pytest.raises(TypeError, match="unsupported operand"):
        image + mask


def test_set_return_type():
    img = make_image()

    def result_type():
        # What class does an arbitrary (non-forced) op currently return?
        return type(img + 3)

    # Default: plain tensors come back.
    assert result_type() is torch.Tensor

    # As a context manager the setting is scoped to the block ...
    with datapoints.set_return_type("datapoint"):
        assert result_type() is datapoints.Image
    assert result_type() is torch.Tensor

    # ... while a bare call changes the global default.
    datapoints.set_return_type("datapoint")
    assert result_type() is datapoints.Image

    with datapoints.set_return_type("tensor"):
        assert result_type() is torch.Tensor
        with datapoints.set_return_type("datapoint"):
            assert result_type() is datapoints.Image
            datapoints.set_return_type("tensor")
            assert result_type() is torch.Tensor
        assert result_type() is torch.Tensor
    # Exiting a context manager will restore the return type as it was prior to entering it,
    # regardless of whether the "global" datapoints.set_return_type() was called within the context manager.
    assert result_type() is datapoints.Image

    datapoints.set_return_type("tensor")