test_utils.py 9.29 KB
Newer Older
1
import os
import re
import sys
import tempfile
import unittest
from io import BytesIO

import numpy as np
import torch
import torchvision.transforms.functional as F
import torchvision.utils as utils
from PIL import Image, __version__ as PILLOW_VERSION


# Pillow version as a tuple of ints, e.g. "8.2.0" -> (8, 2, 0), used to gate
# version-dependent reference-image comparisons. Only the leading dotted-numeric
# prefix is parsed: development / pre-release builds report strings such as
# "9.1.0.dev0" or "10.0.0rc1", and calling int() on "dev0" / "0rc1" would raise
# ValueError at import time and abort collection of every test in this module.
PILLOW_VERSION = tuple(
    int(part) for part in re.match(r'\d+(?:\.\d+)*', PILLOW_VERSION).group().split('.')
)
# Four boxes shared by the box-drawing tests, one row per box (presumably in
# (xmin, ymin, xmax, ymax) order — the format draw_bounding_boxes consumes).
# The second box has zero extent.
boxes = torch.tensor(
    [
        [0, 0, 20, 20],
        [0, 0, 0, 0],
        [10, 15, 30, 35],
        [23, 35, 93, 95],
    ],
    dtype=torch.float,
)

# Three 5x5 float masks shared by the segmentation-mask tests. Every row is
# constant across its five columns, so each row is written as ``[value] * 5``.
masks = torch.tensor(
    [
        [
            [-2.2799] * 5,
            [5.0914] * 5,
            [-2.2799] * 5,
            [-2.2799] * 5,
            [-2.2799] * 5,
        ],
        [
            [5.0914] * 5,
            [-2.2799] * 5,
            [5.0914] * 5,
            [5.0914] * 5,
            [-1.4541] * 5,
        ],
        [
            [-1.4541] * 5,
            [-1.4541] * 5,
            [-1.4541] * 5,
            [-1.4541] * 5,
            [5.0914] * 5,
        ],
    ],
    dtype=torch.float,
)

class Tester(unittest.TestCase):

    def test_make_grid_not_inplace(self):
        t = torch.rand(5, 3, 10, 10)
        t_clone = t.clone()

        utils.make_grid(t, normalize=False)
50
        self.assertTrue(torch.equal(t, t_clone), 'make_grid modified tensor in-place')
51
52

        utils.make_grid(t, normalize=True, scale_each=False)
53
        self.assertTrue(torch.equal(t, t_clone), 'make_grid modified tensor in-place')
54
55

        utils.make_grid(t, normalize=True, scale_each=True)
56
        self.assertTrue(torch.equal(t, t_clone), 'make_grid modified tensor in-place')
57

58
59
60
61
62
63
64
65
66
67
68
69
70
71
    def test_normalize_in_make_grid(self):
        t = torch.rand(5, 3, 10, 10) * 255
        norm_max = torch.tensor(1.0)
        norm_min = torch.tensor(0.0)

        grid = utils.make_grid(t, normalize=True)
        grid_max = torch.max(grid)
        grid_min = torch.min(grid)

        # Rounding the result to one decimal for comparison
        n_digits = 1
        rounded_grid_max = torch.round(grid_max * 10 ** n_digits) / (10 ** n_digits)
        rounded_grid_min = torch.round(grid_min * 10 ** n_digits) / (10 ** n_digits)

72
73
        self.assertTrue(torch.equal(norm_max, rounded_grid_max), 'Normalized max is not equal to 1')
        self.assertTrue(torch.equal(norm_min, rounded_grid_min), 'Normalized min is not equal to 0')
74

Nicolas Hug's avatar
Nicolas Hug committed
75
    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
76
    def test_save_image(self):
77
78
79
        with tempfile.NamedTemporaryFile(suffix='.png') as f:
            t = torch.rand(2, 3, 64, 64)
            utils.save_image(t, f.name)
80
            self.assertTrue(os.path.exists(f.name), 'The image is not present after save')
81

Nicolas Hug's avatar
Nicolas Hug committed
82
    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
83
84
85
86
    def test_save_image_single_pixel(self):
        with tempfile.NamedTemporaryFile(suffix='.png') as f:
            t = torch.rand(1, 3, 1, 1)
            utils.save_image(t, f.name)
87
            self.assertTrue(os.path.exists(f.name), 'The pixel image is not present after save')
88

Nicolas Hug's avatar
Nicolas Hug committed
89
    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
90
91
92
93
94
95
96
97
    def test_save_image_file_object(self):
        with tempfile.NamedTemporaryFile(suffix='.png') as f:
            t = torch.rand(2, 3, 64, 64)
            utils.save_image(t, f.name)
            img_orig = Image.open(f.name)
            fp = BytesIO()
            utils.save_image(t, fp, format='png')
            img_bytes = Image.open(fp)
98
99
            self.assertTrue(torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)),
                            'Image not stored in file object')
100

Nicolas Hug's avatar
Nicolas Hug committed
101
    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
102
103
104
105
106
107
108
109
    def test_save_image_single_pixel_file_object(self):
        with tempfile.NamedTemporaryFile(suffix='.png') as f:
            t = torch.rand(1, 3, 1, 1)
            utils.save_image(t, f.name)
            img_orig = Image.open(f.name)
            fp = BytesIO()
            utils.save_image(t, fp, format='png')
            img_bytes = Image.open(fp)
110
111
            self.assertTrue(torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)),
                            'Pixel Image not stored in file object')
112

113
114
    def test_draw_boxes(self):
        img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
115
116
        img_cp = img.clone()
        boxes_cp = boxes.clone()
117
118
        labels = ["a", "b", "c", "d"]
        colors = ["green", "#FF00FF", (0, 255, 0), "red"]
119
        result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors, fill=True)
120
121
122

        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_util.png")
        if not os.path.exists(path):
123
124
            res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
            res.save(path)
125

Nicolas Hug's avatar
Nicolas Hug committed
126
127
128
129
130
        if PILLOW_VERSION >= (8, 2):
            # The reference image is only valid for new PIL versions
            expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
            self.assertTrue(torch.equal(result, expected))

131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
        # Check if modification is not in place
        self.assertTrue(torch.all(torch.eq(boxes, boxes_cp)).item())
        self.assertTrue(torch.all(torch.eq(img, img_cp)).item())

    def test_draw_boxes_vanilla(self):
        img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
        img_cp = img.clone()
        boxes_cp = boxes.clone()
        result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7)

        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png")
        if not os.path.exists(path):
            res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
            res.save(path)

        expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
        self.assertTrue(torch.equal(result, expected))
        # Check if modification is not in place
        self.assertTrue(torch.all(torch.eq(boxes, boxes_cp)).item())
        self.assertTrue(torch.all(torch.eq(img, img_cp)).item())

    def test_draw_invalid_boxes(self):
        img_tp = ((1, 1, 1), (1, 2, 3))
        img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
        img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
        boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
                             [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
        self.assertRaises(TypeError, utils.draw_bounding_boxes, img_tp, boxes)
        self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong1, boxes)
        self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong2, boxes)
161

162
163
    def test_draw_segmentation_masks_colors(self):
        img = torch.full((3, 5, 5), 255, dtype=torch.uint8)
164
165
        img_cp = img.clone()
        masks_cp = masks.clone()
166
167
168
169
170
171
172
173
174
175
176
177
        colors = ["#FF00FF", (0, 255, 0), "red"]
        result = utils.draw_segmentation_masks(img, masks, colors=colors)

        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
                            "fakedata", "draw_segm_masks_colors_util.png")

        if not os.path.exists(path):
            res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
            res.save(path)

        expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
        self.assertTrue(torch.equal(result, expected))
178
179
180
        # Check if modification is not in place
        self.assertTrue(torch.all(torch.eq(img, img_cp)).item())
        self.assertTrue(torch.all(torch.eq(masks, masks_cp)).item())
181
182
183

    def test_draw_segmentation_masks_no_colors(self):
        img = torch.full((3, 20, 20), 255, dtype=torch.uint8)
184
185
        img_cp = img.clone()
        masks_cp = masks.clone()
186
187
188
189
190
191
192
193
194
195
196
        result = utils.draw_segmentation_masks(img, masks, colors=None)

        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
                            "fakedata", "draw_segm_masks_no_colors_util.png")

        if not os.path.exists(path):
            res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
            res.save(path)

        expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
        self.assertTrue(torch.equal(result, expected))
197
198
199
200
201
202
203
204
205
206
207
208
209
210
        # Check if modification is not in place
        self.assertTrue(torch.all(torch.eq(img, img_cp)).item())
        self.assertTrue(torch.all(torch.eq(masks, masks_cp)).item())

    def test_draw_invalid_masks(self):
        img_tp = ((1, 1, 1), (1, 2, 3))
        img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
        img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
        img_wrong3 = torch.full((4, 5, 5), 255, dtype=torch.uint8)

        self.assertRaises(TypeError, utils.draw_segmentation_masks, img_tp, masks)
        self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong1, masks)
        self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong2, masks)
        self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong3, masks)
211

212
213
214

if __name__ == '__main__':
    # Allow running this module directly as a script, in addition to test-runner discovery.
    unittest.main()