from copy import deepcopy

import pytest
import torch
from common_utils import assert_equal, make_bounding_boxes, make_image, make_segmentation_mask, make_video
from PIL import Image

from torchvision import tv_tensors


@pytest.fixture(autouse=True)
def restore_tensor_return_type():
    """Reset the global return type to "Tensor" after every test.

    This is for security, as we should already be restoring the default
    manually in each test anyway (at least at the time of writing...).
    """
    yield
    tv_tensors.set_return_type("Tensor")


@pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)])
def test_image_instance(data):
    """An Image wraps tensor or PIL input and reports a 3-channel, 3D shape."""
    img = tv_tensors.Image(data)
    assert isinstance(img, torch.Tensor)
    assert img.ndim == 3
    assert img.shape[0] == 3


@pytest.mark.parametrize("data", [torch.randint(0, 10, size=(1, 32, 32)), Image.new("L", (32, 32), color=2)])
def test_mask_instance(data):
    """A Mask wraps tensor or PIL input and reports a single-channel, 3D shape."""
    seg_mask = tv_tensors.Mask(data)
    assert isinstance(seg_mask, torch.Tensor)
    assert seg_mask.ndim == 3
    assert seg_mask.shape[0] == 1


@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]], [1, 2, 3, 4]])
@pytest.mark.parametrize(
    "format", ["XYXY", "CXCYWH", tv_tensors.BoundingBoxFormat.XYXY, tv_tensors.BoundingBoxFormat.XYWH]
)
def test_bbox_instance(data, format):
    """BoundingBoxes is a 2D Nx4 tensor that remembers its format."""
    boxes = tv_tensors.BoundingBoxes(data, format=format, canvas_size=(32, 32))
    assert isinstance(boxes, torch.Tensor)
    assert boxes.ndim == 2
    assert boxes.shape[1] == 4
    # String formats are normalized to the corresponding enum member on construction.
    expected = tv_tensors.BoundingBoxFormat[format.upper()] if isinstance(format, str) else format
    assert boxes.format == expected


def test_bbox_dim_error():
    """Constructing bounding boxes from 3D data must raise a ValueError."""
    nested_boxes = [[[1, 2, 3, 4]]]
    with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"):
        tv_tensors.BoundingBoxes(nested_boxes, format="XYXY", canvas_size=(32, 32))


@pytest.mark.parametrize(
    ("data", "input_requires_grad", "expected_requires_grad"),
    [
        ([[[0.0, 1.0], [0.0, 1.0]]], None, False),
        ([[[0.0, 1.0], [0.0, 1.0]]], False, False),
        ([[[0.0, 1.0], [0.0, 1.0]]], True, True),
        (torch.rand(3, 16, 16, requires_grad=False), None, False),
        (torch.rand(3, 16, 16, requires_grad=False), False, False),
        (torch.rand(3, 16, 16, requires_grad=False), True, True),
        (torch.rand(3, 16, 16, requires_grad=True), None, True),
        (torch.rand(3, 16, 16, requires_grad=True), False, False),
        (torch.rand(3, 16, 16, requires_grad=True), True, True),
    ],
)
def test_new_requires_grad(data, input_requires_grad, expected_requires_grad):
    """The requires_grad kwarg overrides the input's flag; None inherits it."""
    image = tv_tensors.Image(data, requires_grad=input_requires_grad)
    assert image.requires_grad is expected_requires_grad


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
def test_isinstance(make_input):
    """Every tv_tensor subclass instance is a torch.Tensor."""
    sample = make_input()
    assert isinstance(sample, torch.Tensor)


def test_wrapping_no_copy():
    """Wrapping a plain tensor in an Image shares the underlying storage."""
    backing = torch.rand(3, 16, 16)
    wrapped = tv_tensors.Image(backing)
    assert wrapped.data_ptr() == backing.data_ptr()


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
def test_to_wrapping(make_input):
    """.to() converts the dtype while keeping the tv_tensor subclass."""
    inp = make_input()

    converted = inp.to(torch.float64)

    assert type(converted) is type(inp)
    assert converted.dtype is torch.float64


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_to_tv_tensor_reference(make_input, return_type):
    """Tensor.to(tv_tensor) adopts the reference dtype; wrapping follows the return type."""
    plain = torch.rand((3, 16, 16), dtype=torch.float64)
    reference = make_input()

    with tv_tensors.set_return_type(return_type):
        converted = plain.to(reference)

    expected_type = type(reference) if return_type == "tv_tensor" else torch.Tensor
    assert type(converted) is expected_type
    assert converted.dtype is reference.dtype
    # The source tensor itself must stay a plain Tensor.
    assert type(plain) is torch.Tensor


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_clone_wrapping(make_input, return_type):
    """.clone() keeps the subclass and allocates new storage, for both return types."""
    inp = make_input()

    with tv_tensors.set_return_type(return_type):
        cloned = inp.clone()

    assert type(cloned) is type(inp)
    assert cloned.data_ptr() != inp.data_ptr()


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_requires_grad__wrapping(make_input, return_type):
    """.requires_grad_() acts in place and keeps the subclass, for both return types."""
    inp = make_input(dtype=torch.float)
    assert not inp.requires_grad

    with tv_tensors.set_return_type(return_type):
        result = inp.requires_grad_(True)

    assert type(result) is type(inp)
    # In-place: the original object is flipped too.
    assert inp.requires_grad
    assert result.requires_grad


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_detach_wrapping(make_input, return_type):
    """.detach() keeps the tv_tensor subclass, for both return types."""
    inp = make_input(dtype=torch.float).requires_grad_(True)

    with tv_tensors.set_return_type(return_type):
        detached = inp.detach()

    assert type(detached) is type(inp)


@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_force_subclass_with_metadata(return_type):
    """Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS on tv_tensors with metadata.

    Largely the same as the wrapping tests above; additionally checks that the
    format/canvas_size metadata is preserved across each op.
    """
    format, canvas_size = "XYXY", (32, 32)
    # The constructor normalizes the string format to the enum member, so that is
    # what the preserved metadata must compare equal to.
    expected_format = tv_tensors.BoundingBoxFormat[format]
    bbox = tv_tensors.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size)

    def assert_metadata_preserved():
        # Bug fix: the original `assert bbox.format, bbox.canvas_size == (...)` parsed
        # as `assert bbox.format` with a (never shown) message, so the metadata
        # comparison was never actually checked.
        assert (bbox.format, bbox.canvas_size) == (expected_format, canvas_size)

    tv_tensors.set_return_type(return_type)
    bbox = bbox.clone()
    if return_type == "tv_tensor":
        assert_metadata_preserved()

    bbox = bbox.to(torch.float64)
    if return_type == "tv_tensor":
        assert_metadata_preserved()

    bbox = bbox.detach()
    if return_type == "tv_tensor":
        assert_metadata_preserved()

    assert not bbox.requires_grad
    bbox.requires_grad_(True)
    if return_type == "tv_tensor":
        assert_metadata_preserved()
        assert bbox.requires_grad
    tv_tensors.set_return_type("tensor")


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_other_op_no_wrapping(make_input, return_type):
    """Ops outside _FORCE_TORCHFUNCTION_SUBCLASS only re-wrap when the return type says so."""
    inp = make_input()

    with tv_tensors.set_return_type(return_type):
        # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
        doubled = inp * 2

    expected_type = type(inp) if return_type == "tv_tensor" else torch.Tensor
    assert type(doubled) is expected_type


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize(
    "op",
    [
        lambda t: t.numpy(),
        lambda t: t.tolist(),
        lambda t: t.max(dim=-1),
    ],
)
def test_no_tensor_output_op_no_wrapping(make_input, op):
    """Ops whose output is not a plain Tensor never return the tv_tensor subclass."""
    inp = make_input()

    result = op(inp)

    assert type(result) is not type(inp)


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_inplace_op_no_wrapping(make_input, return_type):
    """In-place ops follow the return type for their output but never retype the input."""
    inp = make_input()
    original_type = type(inp)

    with tv_tensors.set_return_type(return_type):
        result = inp.add_(0)

    expected_type = original_type if return_type == "tv_tensor" else torch.Tensor
    assert type(result) is expected_type
    assert type(inp) is original_type


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
def test_wrap(make_input):
    """tv_tensors.wrap re-wraps a plain tensor as `like` without copying its data."""
    reference = make_input()

    # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
    unwrapped = reference * 2

    rewrapped = tv_tensors.wrap(unwrapped, like=reference)

    assert type(rewrapped) is type(reference)
    assert rewrapped.data_ptr() == unwrapped.data_ptr()


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("requires_grad", [False, True])
def test_deepcopy(make_input, requires_grad):
    """deepcopy yields an equal, independent copy that keeps the type and grad flag."""
    original = make_input(dtype=torch.float)
    original.requires_grad_(requires_grad)

    copied = deepcopy(original)

    assert copied is not original
    assert copied.data_ptr() != original.data_ptr()
    assert_equal(copied, original)

    assert type(copied) is type(original)
    assert copied.requires_grad is requires_grad


@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
@pytest.mark.parametrize(
    "op",
    (
        lambda dp: dp + torch.rand(*dp.shape),
        lambda dp: torch.rand(*dp.shape) + dp,
        lambda dp: dp * torch.rand(*dp.shape),
        lambda dp: torch.rand(*dp.shape) * dp,
        lambda dp: dp + 3,
        lambda dp: 3 + dp,
        lambda dp: dp + dp,
        lambda dp: dp.sum(),
        lambda dp: dp.reshape(-1),
        lambda dp: dp.int(),
        lambda dp: torch.stack([dp, dp]),
        lambda dp: torch.chunk(dp, 2)[0],
        lambda dp: torch.unbind(dp)[0],
    ),
)
def test_usual_operations(make_input, return_type, op):
    """Common tensor ops re-wrap their result only when the return type is tv_tensor."""
    inp = make_input()
    with tv_tensors.set_return_type(return_type):
        result = op(inp)

    expected_type = type(inp) if return_type == "tv_tensor" else torch.Tensor
    assert type(result) is expected_type
    if isinstance(inp, tv_tensors.BoundingBoxes) and return_type == "tv_tensor":
        # Bounding boxes must keep their metadata attributes when re-wrapped.
        assert hasattr(result, "format")
        assert hasattr(result, "canvas_size")


def test_subclasses():
    """Binary ops between two different tv_tensor subclasses must be rejected."""
    image = make_image()
    seg_mask = make_segmentation_mask()

    with pytest.raises(TypeError, match="unsupported operand"):
        image + seg_mask


def test_set_return_type():
    """set_return_type works both as a context manager and as a global toggle."""
    image = make_image()

    # Default: plain Tensor results.
    assert type(image + 3) is torch.Tensor

    with tv_tensors.set_return_type("tv_tensor"):
        assert type(image + 3) is tv_tensors.Image
    assert type(image + 3) is torch.Tensor

    # Global (non-context-manager) use sticks until changed again.
    tv_tensors.set_return_type("tv_tensor")
    assert type(image + 3) is tv_tensors.Image

    with tv_tensors.set_return_type("tensor"):
        assert type(image + 3) is torch.Tensor

        with tv_tensors.set_return_type("tv_tensor"):
            assert type(image + 3) is tv_tensors.Image
            tv_tensors.set_return_type("tensor")
            assert type(image + 3) is torch.Tensor
        assert type(image + 3) is torch.Tensor
    # Exiting a context manager will restore the return type as it was prior to entering it,
    # regardless of whether the "global" tv_tensors.set_return_type() was called within the context manager.
    assert type(image + 3) is tv_tensors.Image

    tv_tensors.set_return_type("tensor")