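"""Tests for ``torchvision.utils``: make_grid, save_image, draw_bounding_boxes,
draw_segmentation_masks, draw_keypoints and flow_to_image."""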
1
import os
2
import re
3
import sys
4
import tempfile
5
from io import BytesIO
6
7
8
9

import numpy as np
import pytest
import torch
10
import torchvision.transforms.functional as F
11
import torchvision.utils as utils
12
from common_utils import assert_equal
13
from PIL import __version__ as PILLOW_VERSION, Image, ImageColor
14
15


16
def _parse_version(version):
    """Parse a dotted version string into a tuple of ints for ordered comparison.

    Only the leading digits of each dot-separated component are kept, and parsing
    stops at the first component that has none. Pre-release / dev builds such as
    ``"9.1.0.dev0"`` therefore give ``(9, 1, 0)`` instead of raising ``ValueError``.
    """
    numbers = []
    for component in version.split("."):
        match = re.match(r"\d+", component)
        if match is None:
            break
        numbers.append(int(match.group()))
    return tuple(numbers)


# Installed Pillow version as a tuple of ints, compared against e.g. (8, 2) to gate
# tests whose reference images depend on the Pillow release.
PILLOW_VERSION = _parse_version(PILLOW_VERSION)

# Shared fixture for the draw_bounding_boxes tests: four boxes, the second of which
# has zero area.
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)

# Shared fixture for the draw_keypoints tests: shape (2, 3, 2), i.e. 2 instances with
# 3 points each (presumably (x, y) pixel coordinates).
keypoints = torch.tensor([[[10, 10], [5, 5], [2, 2]], [[20, 20], [30, 30], [3, 3]]], dtype=torch.float)

22

23
24
25
26
27
def test_make_grid_not_inplace():
    """make_grid must leave its input untouched for every normalization setting."""
    source = torch.rand(5, 3, 10, 10)
    reference = source.clone()

    option_sets = (
        {"normalize": False},
        {"normalize": True, "scale_each": False},
        {"normalize": True, "scale_each": True},
    )
    for options in option_sets:
        utils.make_grid(source, **options)
        assert_equal(source, reference, msg="make_grid modified tensor in-place")


def test_normalize_in_make_grid():
    """With normalize=True the grid should span [0, 1], compared after rounding to one decimal."""
    images = torch.rand(5, 3, 10, 10) * 255
    grid = utils.make_grid(images, normalize=True)

    def _round_one_decimal(value):
        # Round to a single decimal place so tiny numerical noise does not matter.
        scale = 10**1
        return torch.round(value * scale) / scale

    assert_equal(torch.tensor(1.0), _round_one_decimal(torch.max(grid)), msg="Normalized max is not equal to 1")
    assert_equal(torch.tensor(0.0), _round_one_decimal(torch.min(grid)), msg="Normalized min is not equal to 0")


55
@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image():
    """save_image should write a batch of images to the given .png path."""
    with tempfile.NamedTemporaryFile(suffix=".png") as f:
        t = torch.rand(2, 3, 64, 64)
        utils.save_image(t, f.name)
        assert os.path.exists(f.name), "The image is not present after save"
        # NamedTemporaryFile has already created an (empty) file at f.name, so the
        # existence check above cannot fail on its own; verify data was actually written.
        assert os.path.getsize(f.name) > 0, "The image is empty after save"

62

63
@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image_single_pixel():
    """save_image should handle a batch holding a single 1x1 image."""
    with tempfile.NamedTemporaryFile(suffix=".png") as f:
        t = torch.rand(1, 3, 1, 1)
        utils.save_image(t, f.name)
        assert os.path.exists(f.name), "The pixel image is not present after save"
        # NamedTemporaryFile has already created an (empty) file at f.name, so the
        # existence check above cannot fail on its own; verify data was actually written.
        assert os.path.getsize(f.name) > 0, "The pixel image is empty after save"


71
@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image_file_object():
    """Saving to an in-memory file object must give the same pixels as saving to a path."""
    with tempfile.NamedTemporaryFile(suffix=".png") as tmp:
        batch = torch.rand(2, 3, 64, 64)

        utils.save_image(batch, tmp.name)
        from_path = Image.open(tmp.name)

        buffer = BytesIO()
        utils.save_image(batch, buffer, format="png")
        from_buffer = Image.open(buffer)

        assert_equal(
            F.pil_to_tensor(from_path),
            F.pil_to_tensor(from_buffer),
            msg="Image not stored in file object",
        )


83
@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
def test_save_image_single_pixel_file_object():
    """A single 1x1 image saved to a file object must match the same image saved to a path."""
    with tempfile.NamedTemporaryFile(suffix=".png") as tmp:
        pixel = torch.rand(1, 3, 1, 1)

        utils.save_image(pixel, tmp.name)
        from_path = Image.open(tmp.name)

        buffer = BytesIO()
        utils.save_image(pixel, buffer, format="png")
        from_buffer = Image.open(buffer)

        assert_equal(
            F.pil_to_tensor(from_path),
            F.pil_to_tensor(from_buffer),
            msg="Image not stored in file object",
        )


def test_draw_boxes():
    """Filled, labelled boxes in mixed colour specs must match the stored reference image."""
    canvas = torch.full((3, 100, 100), 255, dtype=torch.uint8)
    canvas_before = canvas.clone()
    boxes_before = boxes.clone()

    result = utils.draw_bounding_boxes(
        canvas,
        boxes,
        labels=["a", "b", "c", "d"],
        colors=["green", "#FF00FF", (0, 255, 0), "red"],
        fill=True,
    )

    assets_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata")
    path = os.path.join(assets_dir, "draw_boxes_util.png")
    if not os.path.exists(path):
        # No reference yet: record the current output as the reference image.
        Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()).save(path)

    # The reference image is only valid for new PIL versions
    if PILLOW_VERSION >= (8, 2):
        expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
        assert_equal(result, expected)

    # Check if modification is not in place
    assert_equal(boxes, boxes_before)
    assert_equal(canvas, canvas_before)


118
@pytest.mark.parametrize("colors", [None, ["red", "blue", "#FF00FF", (1, 34, 122)], "red", "#FF00FF", (1, 34, 122)])
def test_draw_boxes_colors(colors):
    """Every accepted colour spec draws without error; an empty colour list is rejected."""
    black = torch.full((3, 100, 100), 0, dtype=torch.uint8)
    utils.draw_bounding_boxes(black, boxes, fill=False, width=7, colors=colors)

    too_few_colors = "Number of colors must be equal or larger than the number of objects"
    with pytest.raises(ValueError, match=too_few_colors):
        utils.draw_bounding_boxes(image=black, boxes=boxes, colors=[])

126

127
128
129
130
def test_draw_boxes_vanilla():
    """Unfilled white boxes of width 7 on a black image must match the stored reference image."""
    canvas = torch.full((3, 100, 100), 0, dtype=torch.uint8)
    canvas_before = canvas.clone()
    boxes_before = boxes.clone()

    result = utils.draw_bounding_boxes(canvas, boxes, fill=False, width=7, colors="white")

    assets_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata")
    path = os.path.join(assets_dir, "draw_boxes_vanilla.png")
    if not os.path.exists(path):
        # No reference yet: record the current output as the reference image.
        Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()).save(path)

    reference = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
    assert_equal(result, reference)
    # Check if modification is not in place
    assert_equal(boxes, boxes_before)
    assert_equal(canvas, canvas_before)


145
146
147
148
149
150
151
def test_draw_boxes_grayscale():
    """Drawing on a single-channel image must yield a 3-channel result."""
    gray = torch.full((1, 4, 4), fill_value=255, dtype=torch.uint8)
    single_box = torch.tensor([[0, 0, 3, 3]], dtype=torch.int64)

    drawn = utils.draw_bounding_boxes(image=gray, boxes=single_box, colors=["#1BBC9B"])

    assert drawn.size(0) == 3


152
153
154
155
def test_draw_invalid_boxes():
    """draw_bounding_boxes rejects malformed images, boxes, labels and colour lists."""
    not_a_tensor = ((1, 1, 1), (1, 2, 3))
    float_img = torch.full((3, 5, 5), 255, dtype=torch.float)
    batched_img = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
    good_img = torch.zeros((3, 10, 10), dtype=torch.uint8)
    good_boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
    # In both rows the second corner lies before the first one.
    inverted_boxes = torch.tensor([[10, 10, 4, 5], [30, 20, 10, 5]], dtype=torch.float)

    cases = [
        # (expected exception, message pattern, positional args, keyword args)
        (TypeError, "Tensor expected", (not_a_tensor, good_boxes), {}),
        (ValueError, "Tensor uint8 expected", (float_img, good_boxes), {}),
        (ValueError, "Pass individual images, not batches", (batched_img, good_boxes), {}),
        # Two channels: neither grayscale nor RGB.
        (ValueError, "Only grayscale and RGB images are supported", (batched_img[0][:2], good_boxes), {}),
        # Two labels for four boxes.
        (ValueError, "Number of boxes", (good_img, good_boxes, ["one", "two"]), {}),
        # Two colours for four boxes.
        (ValueError, "Number of colors", (good_img, good_boxes), {"colors": ["pink", "blue"]}),
        (ValueError, "Boxes need to be in", (good_img, inverted_boxes), {}),
    ]
    for exc_type, pattern, args, kwargs in cases:
        with pytest.raises(exc_type, match=pattern):
            utils.draw_bounding_boxes(*args, **kwargs)

177

178
179
180
181
182
183
184
def test_draw_boxes_warning():
    """Passing font_size without a font should emit a UserWarning that it is ignored."""
    canvas = torch.full((3, 100, 100), 255, dtype=torch.uint8)
    expected_message = re.escape("Argument 'font_size' will be ignored since 'font' is not set.")

    with pytest.warns(UserWarning, match=expected_message):
        utils.draw_bounding_boxes(canvas, boxes, font_size=11)


185
186
187
188
189
def test_draw_no_boxes():
    """An empty (0, 4) box tensor triggers a warning and leaves the image unchanged."""
    canvas = torch.full((3, 100, 100), 0, dtype=torch.uint8)
    no_boxes = torch.full((0, 4), 0, dtype=torch.float)

    with pytest.warns(UserWarning, match=re.escape("boxes doesn't contain any box. No box was drawn")):
        drawn = utils.draw_bounding_boxes(canvas, no_boxes)

    # Check that the function didn't change the image
    assert drawn.eq(canvas).all()


194
195
196
197
@pytest.mark.parametrize(
    "colors",
    [
        None,
        "blue",
        "#FF00FF",
        (1, 34, 122),
        ["red", "blue"],
        ["#FF00FF", (1, 34, 122)],
    ],
)
@pytest.mark.parametrize("alpha", (0, 0.5, 0.7, 1))
def test_draw_segmentation_masks(colors, alpha):
    """This test makes sure that masks draw their corresponding color where they should"""
    num_masks, h, w = 2, 100, 100
    dtype = torch.uint8
    img = torch.randint(0, 256, size=(3, h, w), dtype=dtype)
    masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool)

    # For testing we enforce that there's no overlap between the masks. The
    # current behaviour is that the last mask's color will take priority when
    # masks overlap, but this makes testing slightly harder, so we don't really
    # care
    overlap = masks[0] & masks[1]
    masks[:, overlap] = False

    out = utils.draw_segmentation_masks(img, masks, colors=colors, alpha=alpha)
    assert out.dtype == dtype
    assert out is not img

    # Make sure the image didn't change where there's no mask
    masked_pixels = masks[0] | masks[1]
    assert_equal(img[:, ~masked_pixels], out[:, ~masked_pixels])

    if colors is None:
        colors = utils._generate_color_palette(num_masks)
    elif isinstance(colors, (str, tuple)):
        # A single colour spec is used for every mask. Repeat it so the loop below
        # checks all masks; a one-element list would make zip() stop after the first.
        colors = [colors] * num_masks

    # Make sure each mask draws with its own color
    for mask, color in zip(masks, colors):
        if isinstance(color, str):
            color = ImageColor.getrgb(color)
        color = torch.tensor(color, dtype=dtype)

        if alpha == 1:
            assert (out[:, mask] == color[:, None]).all()
        elif alpha == 0:
            assert (out[:, mask] == img[:, mask]).all()

        # General case: output is the alpha blend of image and colour, up to rounding.
        interpolated_color = (img[:, mask] * (1 - alpha) + color[:, None] * alpha).to(dtype)
        torch.testing.assert_close(out[:, mask], interpolated_color, rtol=0.0, atol=1.0)


def test_draw_segmentation_masks_errors():
    """draw_segmentation_masks rejects malformed images, masks and colour specs."""
    h, w = 10, 10

    masks = torch.randint(0, 2, size=(h, w), dtype=torch.bool)
    img = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8)

    cases = [
        # (expected exception, message pattern, keyword arguments)
        (TypeError, "The image must be a tensor", {"image": "Not A Tensor Image", "masks": masks}),
        (
            ValueError,
            "The image dtype must be",
            {"image": torch.randint(0, 256, size=(3, h, w), dtype=torch.int64), "masks": masks},
        ),
        (
            ValueError,
            "Pass individual images, not batches",
            {"image": torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8), "masks": masks},
        ),
        (
            ValueError,
            "Pass an RGB image",
            {"image": torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8), "masks": masks},
        ),
        (
            ValueError,
            "The masks must be of dtype bool",
            {"image": img, "masks": torch.randint(0, 2, size=(h, w), dtype=torch.float)},
        ),
        (
            ValueError,
            "masks must be of shape",
            {"image": img, "masks": torch.randint(0, 2, size=(3, 2, h, w), dtype=torch.bool)},
        ),
        (
            ValueError,
            "must have the same height and width",
            {"image": img, "masks": torch.randint(0, 2, size=(h + 4, w), dtype=torch.bool)},
        ),
        (
            ValueError,
            "Number of colors must be equal or larger than the number of objects",
            {"image": img, "masks": masks, "colors": []},
        ),
        (
            ValueError,
            "`colors` must be a tuple or a string, or a list thereof",
            # A numpy array is not accepted; it should be a list.
            {"image": img, "masks": masks, "colors": np.array(["red", "blue"])},
        ),
        (
            ValueError,
            "If passed as tuple, colors should be an RGB triplet",
            # A tuple of colour names is not accepted; it should be a list.
            {"image": img, "masks": masks, "colors": ("red", "blue")},
        ),
    ]
    for exc_type, pattern, kwargs in cases:
        with pytest.raises(exc_type, match=pattern):
            utils.draw_segmentation_masks(**kwargs)

283

284
285
286
287
288
def test_draw_no_segmention_mask():
    """An empty stack of masks triggers a warning and leaves the image unchanged."""
    canvas = torch.full((3, 100, 100), 0, dtype=torch.uint8)
    no_masks = torch.full((0, 100, 100), 0, dtype=torch.bool)

    with pytest.warns(UserWarning, match=re.escape("masks doesn't contain any mask. No mask was drawn")):
        drawn = utils.draw_segmentation_masks(canvas, no_masks)

    # Check that the function didn't change the image
    assert drawn.eq(canvas).all()


293
294
295
296
297
298
def test_draw_keypoints_vanilla():
    """Red keypoints joined by a single (0, 1) edge must match the stored reference image."""
    # `keypoints` is the module-level fixture; snapshot it to detect in-place edits.
    keypoints_before = keypoints.clone()

    canvas = torch.full((3, 100, 100), 0, dtype=torch.uint8)
    canvas_before = canvas.clone()

    edges = [(0, 1)]
    result = utils.draw_keypoints(canvas, keypoints, colors="red", connectivity=edges)

    assets_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata")
    path = os.path.join(assets_dir, "draw_keypoint_vanilla.png")
    if not os.path.exists(path):
        # No reference yet: record the current output as the reference image.
        Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()).save(path)

    reference = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
    assert_equal(result, reference)
    # Neither the keypoints nor the image may be modified in place
    assert_equal(keypoints, keypoints_before)
    assert_equal(canvas, canvas_before)


@pytest.mark.parametrize("colors", ["red", "#FF00FF", (1, 34, 122)])
def test_draw_keypoints_colored(colors):
    """Each colour spec yields a 3-channel result without modifying the inputs."""
    # `keypoints` is the module-level fixture; snapshot it to detect in-place edits.
    keypoints_before = keypoints.clone()

    canvas = torch.full((3, 100, 100), 0, dtype=torch.uint8)
    canvas_before = canvas.clone()

    edges = [(0, 1)]
    result = utils.draw_keypoints(canvas, keypoints, colors=colors, connectivity=edges)

    assert result.size(0) == 3
    assert_equal(keypoints, keypoints_before)
    assert_equal(canvas, canvas_before)


def test_draw_keypoints_errors():
    """draw_keypoints rejects non-tensor, wrongly typed or shaped images and bad keypoints."""
    h, w = 10, 10
    img = torch.full((3, 100, 100), 0, dtype=torch.uint8)

    with pytest.raises(TypeError, match="The image must be a tensor"):
        utils.draw_keypoints(image="Not A Tensor Image", keypoints=keypoints)

    value_error_cases = [
        # (message pattern, image, keypoints)
        ("The image dtype must be", torch.full((3, h, w), 0, dtype=torch.int64), keypoints),
        (
            "Pass individual images, not batches",
            torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8),
            keypoints,
        ),
        ("Pass an RGB image", torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8), keypoints),
        (
            "keypoints must be of shape",
            img,
            torch.tensor([[10, 10, 10, 10], [5, 6, 7, 8]], dtype=torch.float),
        ),
    ]
    for pattern, image, points in value_error_cases:
        with pytest.raises(ValueError, match=pattern):
            utils.draw_keypoints(image=image, keypoints=points)


360
361
@pytest.mark.parametrize("batch", (True, False))
def test_flow_to_image(batch):
    """flow_to_image must reproduce the stored reference image, with or without a batch dim.

    The input flow is built from a pixel-coordinate grid with channel 0 = column
    index and channel 1 = row index, each shifted by half the image size.
    """
    h, w = 100, 100
    flow = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    # Reverse (row, col) grids into (col, row) order before stacking into channels.
    flow = torch.stack(flow[::-1], dim=0).float()
    flow[0] -= h / 2
    flow[1] -= w / 2

    if batch:
        flow = torch.stack([flow, flow])

    img = utils.flow_to_image(flow)
    # Compute the expected shape first: ``assert a == b if c else d`` parses as
    # ``assert (a == b) if c else d``, which asserts the truthy tuple ``d`` when c is False.
    expected_shape = (2, 3, h, w) if batch else (3, h, w)
    assert img.shape == expected_shape

    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "expected_flow.pt")
    expected_img = torch.load(path, map_location="cpu")

    if batch:
        expected_img = torch.stack([expected_img, expected_img])

    assert_equal(expected_img, img)


383
384
385
386
387
388
389
390
391
392
393
394
395
@pytest.mark.parametrize(
    "input_flow, match",
    (
        (torch.full((3, 10, 10), 0, dtype=torch.float), "Input flow should have shape"),
        (torch.full((5, 3, 10, 10), 0, dtype=torch.float), "Input flow should have shape"),
        (torch.full((2, 10), 0, dtype=torch.float), "Input flow should have shape"),
        (torch.full((5, 2, 10), 0, dtype=torch.float), "Input flow should have shape"),
        (torch.full((2, 10, 30), 0, dtype=torch.int), "Flow should be of dtype torch.float"),
    ),
)
def test_flow_to_image_errors(input_flow, match):
    """Flows with a wrong shape or a non-float dtype make flow_to_image raise ValueError."""
    expectation = pytest.raises(ValueError, match=match)
    with expectation:
        utils.flow_to_image(flow=input_flow)


398
399
if __name__ == "__main__":
    # Running this file directly hands the module over to pytest.
    pytest.main(args=[__file__])