import os
import re
import sys
import tempfile
from io import BytesIO

import numpy as np
import pytest
import torch
import torchvision.transforms.functional as F
import torchvision.utils as utils
from common_utils import assert_equal, cpu_and_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageColor
from torchvision.transforms.v2.functional import to_dtype


# Installed Pillow version as a tuple of ints, e.g. "10.1.0" -> (10, 1, 0).
# Only the leading numeric release segment is parsed, so pre-release / dev
# builds such as "11.1.0.dev0" give (11, 1, 0) instead of crashing test
# collection with int("dev0"). Plain release strings parse exactly as before.
PILLOW_VERSION = tuple(int(x) for x in re.match(r"\d+(?:\.\d+)*", PILLOW_VERSION).group().split("."))

# Shared box fixture: four boxes of 4 coordinates each (presumably x1, y1, x2, y2
# -- TODO confirm); the second box is degenerate (all zeros).
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)

# Shared keypoint fixture of shape (2, 3, 2): 2 instances with 3 points each
# (presumably (x, y) coordinates).
keypoints = torch.tensor([[[10, 10], [5, 5], [2, 2]], [[20, 20], [30, 30], [3, 3]]], dtype=torch.float)


def test_make_grid_not_inplace():
    """make_grid must never mutate its input tensor, whatever normalization options are used."""
    images = torch.rand(5, 3, 10, 10)
    snapshot = images.clone()

    option_sets = (
        {"normalize": False},
        {"normalize": True, "scale_each": False},
        {"normalize": True, "scale_each": True},
    )
    for options in option_sets:
        utils.make_grid(images, **options)
        assert_equal(images, snapshot, msg="make_grid modified tensor in-place")


def test_normalize_in_make_grid():
    """make_grid(normalize=True) rescales so the grid's max rounds to 1 and its min rounds to 0."""
    images = torch.rand(5, 3, 10, 10) * 255
    expected_max = torch.tensor(1.0)
    expected_min = torch.tensor(0.0)

    grid = utils.make_grid(images, normalize=True)

    # Compare the extrema after rounding them to one decimal place.
    scale = 10**1

    def _round(value):
        return torch.round(value * scale) / scale

    assert_equal(expected_max, _round(torch.max(grid)), msg="Normalized max is not equal to 1")
    assert_equal(expected_min, _round(torch.min(grid)), msg="Normalized min is not equal to 0")


@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image():
    """save_image must write a batch of images to the given path."""
    with tempfile.NamedTemporaryFile(suffix=".png") as f:
        t = torch.rand(2, 3, 64, 64)
        utils.save_image(t, f.name)
        # NamedTemporaryFile has already created f.name (empty), so os.path.exists()
        # would pass even if save_image wrote nothing; require actual content instead.
        assert os.path.getsize(f.name) > 0, "The image is not present after save"


@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image_single_pixel():
    """save_image must also handle a single 1x1 RGB image."""
    with tempfile.NamedTemporaryFile(suffix=".png") as f:
        t = torch.rand(1, 3, 1, 1)
        utils.save_image(t, f.name)
        # NamedTemporaryFile has already created f.name (empty), so os.path.exists()
        # would pass even if save_image wrote nothing; require actual content instead.
        assert os.path.getsize(f.name) > 0, "The pixel image is not present after save"


@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image_file_object():
    """Saving to a path and to an in-memory file object must yield identical pixels."""
    with tempfile.NamedTemporaryFile(suffix=".png") as f:
        t = torch.rand(2, 3, 64, 64)
        utils.save_image(t, f.name)
        fp = BytesIO()
        utils.save_image(t, fp, format="png")
        # Open both images as context managers so their file handles are released
        # deterministically instead of whenever the Image objects get collected.
        with Image.open(f.name) as img_orig, Image.open(fp) as img_bytes:
            assert_equal(F.pil_to_tensor(img_orig), F.pil_to_tensor(img_bytes), msg="Image not stored in file object")


@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image_single_pixel_file_object():
    """A 1x1 image saved to a path and to an in-memory file object must yield identical pixels."""
    with tempfile.NamedTemporaryFile(suffix=".png") as f:
        t = torch.rand(1, 3, 1, 1)
        utils.save_image(t, f.name)
        fp = BytesIO()
        utils.save_image(t, fp, format="png")
        # Open both images as context managers so their file handles are released
        # deterministically instead of whenever the Image objects get collected.
        with Image.open(f.name) as img_orig, Image.open(fp) as img_bytes:
            assert_equal(F.pil_to_tensor(img_orig), F.pil_to_tensor(img_bytes), msg="Image not stored in file object")


def test_draw_boxes():
    """Filled, labelled boxes with mixed color specs must match the stored reference image."""
    image = torch.full((3, 100, 100), 255, dtype=torch.uint8)
    image_before = image.clone()
    boxes_before = boxes.clone()

    result = utils.draw_bounding_boxes(
        image,
        boxes,
        labels=["a", "b", "c", "d"],
        colors=["green", "#FF00FF", (0, 255, 0), "red"],
        fill=True,
    )

    assets_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata")
    path = os.path.join(assets_dir, "draw_boxes_util.png")
    # Bootstrap the reference image the first time the test runs.
    if not os.path.exists(path):
        Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()).save(path)

    # The reference image is only valid for new PIL versions
    if PILLOW_VERSION >= (10, 1):
        expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
        assert_equal(result, expected)

    # Neither the global boxes nor the input image may be modified in place.
    assert_equal(boxes, boxes_before)
    assert_equal(image, image_before)


@pytest.mark.parametrize("colors", [None, ["red", "blue", "#FF00FF", (1, 34, 122)], "red", "#FF00FF", (1, 34, 122)])
def test_draw_boxes_colors(colors):
    """Every supported ``colors`` form is accepted, while an empty color list is rejected."""
    image = torch.zeros((3, 100, 100), dtype=torch.uint8)
    utils.draw_bounding_boxes(image, boxes, fill=False, width=7, colors=colors)

    too_few_colors_msg = "Number of colors must be equal or larger than the number of objects"
    with pytest.raises(ValueError, match=too_few_colors_msg):
        utils.draw_bounding_boxes(image=image, boxes=boxes, colors=[])


def test_draw_boxes_vanilla():
    """Unfilled white boxes (width 7) on a black image must match the stored reference image."""
    image = torch.zeros((3, 100, 100), dtype=torch.uint8)
    image_before = image.clone()
    boxes_before = boxes.clone()

    result = utils.draw_bounding_boxes(image, boxes, fill=False, width=7, colors="white")

    assets_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata")
    path = os.path.join(assets_dir, "draw_boxes_vanilla.png")
    # Bootstrap the reference image the first time the test runs.
    if not os.path.exists(path):
        Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()).save(path)

    expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
    assert_equal(result, expected)

    # Neither the global boxes nor the input image may be modified in place.
    assert_equal(boxes, boxes_before)
    assert_equal(image, image_before)


def test_draw_boxes_grayscale():
    """Drawing on a single-channel image returns a 3-channel image."""
    gray_image = torch.full((1, 4, 4), fill_value=255, dtype=torch.uint8)
    single_box = torch.tensor([[0, 0, 3, 3]], dtype=torch.int64)
    drawn = utils.draw_bounding_boxes(image=gray_image, boxes=single_box, colors=["#1BBC9B"])
    assert drawn.size(0) == 3


def test_draw_invalid_boxes():
    """draw_bounding_boxes validates the image, the labels/colors counts and the box coordinates."""
    img_tuple = ((1, 1, 1), (1, 2, 3))
    img_float = torch.full((3, 5, 5), 255, dtype=torch.float)
    img_batched = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
    img_ok = torch.zeros((3, 10, 10), dtype=torch.uint8)
    valid_boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
    inverted_boxes = torch.tensor([[10, 10, 4, 5], [30, 20, 10, 5]], dtype=torch.float)
    too_few_labels = ["one", "two"]
    too_few_colors = ["pink", "blue"]

    # Invalid images: not a tensor, wrong dtype, batched, and 2 channels.
    with pytest.raises(TypeError, match="Tensor expected"):
        utils.draw_bounding_boxes(img_tuple, valid_boxes)
    with pytest.raises(ValueError, match="Tensor uint8 expected"):
        utils.draw_bounding_boxes(img_float, valid_boxes)
    with pytest.raises(ValueError, match="Pass individual images, not batches"):
        utils.draw_bounding_boxes(img_batched, valid_boxes)
    with pytest.raises(ValueError, match="Only grayscale and RGB images are supported"):
        utils.draw_bounding_boxes(img_batched[0][:2], valid_boxes)

    # Labels / colors that do not cover all four boxes.
    with pytest.raises(ValueError, match="Number of boxes"):
        utils.draw_bounding_boxes(img_ok, valid_boxes, too_few_labels)
    with pytest.raises(ValueError, match="Number of colors"):
        utils.draw_bounding_boxes(img_ok, valid_boxes, colors=too_few_colors)

    # Boxes whose second corner precedes the first one.
    with pytest.raises(ValueError, match="Boxes need to be in"):
        utils.draw_bounding_boxes(img_ok, inverted_boxes)


def test_draw_boxes_warning():
    """Passing font_size without font emits a UserWarning that font_size is ignored."""
    image = torch.full((3, 100, 100), 255, dtype=torch.uint8)
    expected_warning = re.escape("Argument 'font_size' will be ignored since 'font' is not set.")

    with pytest.warns(UserWarning, match=expected_warning):
        utils.draw_bounding_boxes(image, boxes, font_size=11)


def test_draw_no_boxes():
    """An empty box tensor triggers a warning and the returned image equals the input."""
    image = torch.zeros((3, 100, 100), dtype=torch.uint8)
    no_boxes = torch.zeros((0, 4), dtype=torch.float)
    expected_warning = re.escape("boxes doesn't contain any box. No box was drawn")

    with pytest.warns(UserWarning, match=expected_warning):
        out = utils.draw_bounding_boxes(image, no_boxes)
        # Check that the function didn't change the image
        assert (out == image).all()


@pytest.mark.parametrize(
    "colors",
    [
        None,
        "blue",
        "#FF00FF",
        (1, 34, 122),
        ["red", "blue"],
        ["#FF00FF", (1, 34, 122)],
    ],
)
@pytest.mark.parametrize("alpha", (0, 0.5, 0.7, 1))
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_draw_segmentation_masks(colors, alpha, device):
    """This test makes sure that masks draw their corresponding color where they should"""
    num_masks, h, w = 2, 100, 100
    dtype = torch.uint8
    img = torch.randint(0, 256, size=(3, h, w), dtype=dtype, device=device)
    masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool, device=device)

    # For testing we enforce that there's no overlap between the masks. The
    # current behaviour is that the last mask's color will take priority when
    # masks overlap, but this makes testing slightly harder, so we don't really
    # care
    overlap = masks[0] & masks[1]
    masks[:, overlap] = False

    out = utils.draw_segmentation_masks(img, masks, colors=colors, alpha=alpha)
    assert out.dtype == dtype
    assert out is not img

    # Make sure the image didn't change where there's no mask
    masked_pixels = masks[0] | masks[1]
    assert_equal(img[:, ~masked_pixels], out[:, ~masked_pixels])

    if colors is None:
        colors = utils._generate_color_palette(num_masks)
    elif isinstance(colors, (str, tuple)):
        # A single color (string or RGB tuple) is applied to every mask, so expand it to
        # one entry per mask; wrapping it as [colors] made zip() below check only the
        # first mask and silently skip the others.
        colors = [colors] * num_masks

    # Make sure each mask draws with its own color
    for mask, color in zip(masks, colors):
        if isinstance(color, str):
            color = ImageColor.getrgb(color)
        color = torch.tensor(color, dtype=dtype, device=device)

        if alpha == 1:
            assert (out[:, mask] == color[:, None]).all()
        elif alpha == 0:
            assert (out[:, mask] == img[:, mask]).all()

        # For any alpha, masked pixels must be the alpha-blend of image and color,
        # up to one uint8 level of rounding.
        interpolated_color = (img[:, mask] * (1 - alpha) + color[:, None] * alpha).to(dtype)
        torch.testing.assert_close(out[:, mask], interpolated_color, rtol=0.0, atol=1.0)


def test_draw_segmentation_masks_dtypes():
    """uint8 and float inputs give matching outputs (within one uint8 level) and keep their dtype family."""
    num_masks, h, w = 2, 100, 100
    masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool)

    # uint8 path: a new tensor of the same dtype comes back.
    img_uint8 = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8)
    out_uint8 = utils.draw_segmentation_masks(img_uint8, masks)
    assert out_uint8 is not img_uint8
    assert out_uint8.dtype == torch.uint8

    # float path: same pixels rescaled to float, a new floating-point tensor comes back.
    img_float = to_dtype(img_uint8, torch.float, scale=True)
    out_float = utils.draw_segmentation_masks(img_float, masks)
    assert out_float is not img_float
    assert out_float.is_floating_point()

    # Converting the float result back to uint8 must reproduce the uint8 result.
    out_float_as_uint8 = to_dtype(out_float, torch.uint8, scale=True)
    torch.testing.assert_close(out_uint8, out_float_as_uint8, rtol=0, atol=1)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_draw_segmentation_masks_errors(device):
    """draw_segmentation_masks rejects invalid images, masks and colors with the expected errors."""
    h, w = 10, 10

    masks = torch.randint(0, 2, size=(h, w), dtype=torch.bool, device=device)
    img = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8, device=device)

    # Invalid inputs are built on the parametrized device, so the CUDA variant really
    # feeds CUDA tensors to the validation code, and outside pytest.raises, so each
    # block below only wraps the call whose error it expects.
    img_bad_dtype = torch.randint(0, 256, size=(3, h, w), dtype=torch.int64, device=device)
    batch = torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8, device=device)
    one_channel = torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8, device=device)
    masks_bad_dtype = torch.randint(0, 2, size=(h, w), dtype=torch.float, device=device)
    masks_bad_ndim = torch.randint(0, 2, size=(3, 2, h, w), dtype=torch.bool, device=device)
    masks_bad_hw = torch.randint(0, 2, size=(h + 4, w), dtype=torch.bool, device=device)
    colors_ndarray = np.array(["red", "blue"])  # should be a list
    colors_str_tuple = ("red", "blue")  # should be a list

    with pytest.raises(TypeError, match="The image must be a tensor"):
        utils.draw_segmentation_masks(image="Not A Tensor Image", masks=masks)
    with pytest.raises(ValueError, match="The image dtype must be"):
        utils.draw_segmentation_masks(image=img_bad_dtype, masks=masks)
    with pytest.raises(ValueError, match="Pass individual images, not batches"):
        utils.draw_segmentation_masks(image=batch, masks=masks)
    with pytest.raises(ValueError, match="Pass an RGB image"):
        utils.draw_segmentation_masks(image=one_channel, masks=masks)
    with pytest.raises(ValueError, match="The masks must be of dtype bool"):
        utils.draw_segmentation_masks(image=img, masks=masks_bad_dtype)
    with pytest.raises(ValueError, match="masks must be of shape"):
        utils.draw_segmentation_masks(image=img, masks=masks_bad_ndim)
    with pytest.raises(ValueError, match="must have the same height and width"):
        utils.draw_segmentation_masks(image=img, masks=masks_bad_hw)
    with pytest.raises(ValueError, match="Number of colors must be equal or larger than the number of objects"):
        utils.draw_segmentation_masks(image=img, masks=masks, colors=[])
    with pytest.raises(ValueError, match="`colors` must be a tuple or a string, or a list thereof"):
        utils.draw_segmentation_masks(image=img, masks=masks, colors=colors_ndarray)
    with pytest.raises(ValueError, match="If passed as tuple, colors should be an RGB triplet"):
        utils.draw_segmentation_masks(image=img, masks=masks, colors=colors_str_tuple)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_draw_no_segmention_mask(device):
    """An empty mask tensor triggers a warning and the returned image equals the input."""
    image = torch.zeros((3, 100, 100), dtype=torch.uint8, device=device)
    no_masks = torch.zeros((0, 100, 100), dtype=torch.bool, device=device)
    expected_warning = re.escape("masks doesn't contain any mask. No mask was drawn")

    with pytest.warns(UserWarning, match=expected_warning):
        out = utils.draw_segmentation_masks(image, no_masks)
        # Check that the function didn't change the image
        assert (out == image).all()


def test_draw_keypoints_vanilla():
    """Red keypoints joined by a single (0, 1) edge must match the stored reference image."""
    # `keypoints` is the module-level fixture; snapshot it to detect in-place edits.
    keypoints_before = keypoints.clone()
    image = torch.zeros((3, 100, 100), dtype=torch.uint8)
    image_before = image.clone()

    result = utils.draw_keypoints(image, keypoints, colors="red", connectivity=[(0, 1)])

    assets_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata")
    path = os.path.join(assets_dir, "draw_keypoint_vanilla.png")
    # Bootstrap the reference image the first time the test runs.
    if not os.path.exists(path):
        Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()).save(path)

    expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
    assert_equal(result, expected)
    # Check that keypoints are not modified inplace
    assert_equal(keypoints, keypoints_before)
    # Check that image is not modified in place
    assert_equal(image, image_before)


@pytest.mark.parametrize("colors", ["red", "#FF00FF", (1, 34, 122)])
def test_draw_keypoints_colored(colors):
    """Named, hex and RGB-tuple colors are accepted and the inputs are left untouched."""
    # `keypoints` is the module-level fixture; snapshot it to detect in-place edits.
    keypoints_before = keypoints.clone()
    image = torch.zeros((3, 100, 100), dtype=torch.uint8)
    image_before = image.clone()

    drawn = utils.draw_keypoints(image, keypoints, colors=colors, connectivity=[(0, 1)])

    assert drawn.size(0) == 3
    assert_equal(keypoints, keypoints_before)
    assert_equal(image, image_before)


def test_draw_keypoints_errors():
    """draw_keypoints rejects invalid images and keypoint tensors with the expected errors."""
    h, w = 10, 10
    image = torch.zeros((3, 100, 100), dtype=torch.uint8)

    with pytest.raises(TypeError, match="The image must be a tensor"):
        utils.draw_keypoints(image="Not A Tensor Image", keypoints=keypoints)

    int64_image = torch.zeros((3, h, w), dtype=torch.int64)
    with pytest.raises(ValueError, match="The image dtype must be"):
        utils.draw_keypoints(image=int64_image, keypoints=keypoints)

    batched_image = torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8)
    with pytest.raises(ValueError, match="Pass individual images, not batches"):
        utils.draw_keypoints(image=batched_image, keypoints=keypoints)

    single_channel_image = torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8)
    with pytest.raises(ValueError, match="Pass an RGB image"):
        utils.draw_keypoints(image=single_channel_image, keypoints=keypoints)

    # A 2-D tensor instead of the 3-D layout used by the `keypoints` fixture.
    flat_keypoints = torch.tensor([[10, 10, 10, 10], [5, 6, 7, 8]], dtype=torch.float)
    with pytest.raises(ValueError, match="keypoints must be of shape"):
        utils.draw_keypoints(image=image, keypoints=flat_keypoints)


@pytest.mark.parametrize("batch", (True, False))
def test_flow_to_image(batch):
    """flow_to_image on a synthetic centred flow field must match the stored reference, batched or not."""
    h, w = 100, 100
    flow = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    flow = torch.stack(flow[::-1], dim=0).float()
    flow[0] -= h / 2
    flow[1] -= w / 2

    if batch:
        flow = torch.stack([flow, flow])

    img = utils.flow_to_image(flow)
    # The conditional must be evaluated first: `a == b if c else d` parses as
    # `(a == b) if c else d`, which asserted a non-empty (always truthy) tuple in the
    # unbatched case and therefore never checked the output shape.
    expected_shape = (2, 3, h, w) if batch else (3, h, w)
    assert img.shape == expected_shape

    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "expected_flow.pt")
    expected_img = torch.load(path, map_location="cpu", weights_only=True)

    if batch:
        expected_img = torch.stack([expected_img, expected_img])

    assert_equal(expected_img, img)


@pytest.mark.parametrize(
    "input_flow, match",
    [
        (torch.zeros((3, 10, 10), dtype=torch.float), "Input flow should have shape"),
        (torch.zeros((5, 3, 10, 10), dtype=torch.float), "Input flow should have shape"),
        (torch.zeros((2, 10), dtype=torch.float), "Input flow should have shape"),
        (torch.zeros((5, 2, 10), dtype=torch.float), "Input flow should have shape"),
        (torch.zeros((2, 10, 30), dtype=torch.int), "Flow should be of dtype torch.float"),
    ],
)
def test_flow_to_image_errors(input_flow, match):
    """Flows with an invalid shape or a non-float dtype raise ValueError with a matching message."""
    with pytest.raises(ValueError, match=match):
        utils.flow_to_image(flow=input_flow)


if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly exits
    # non-zero when any test fails (the return value was previously discarded).
    sys.exit(pytest.main([__file__]))