# test_utils.py: unit tests for torchvision.utils.
import numpy as np
2
import os
# Blame annotation: Francisco Massa committed
import sys
4
import tempfile
5
6
7
import torch
import torchvision.utils as utils
import unittest
8
9
10
from io import BytesIO
import torchvision.transforms.functional as F
from PIL import Image


# Bounding boxes shared by the draw_bounding_boxes tests, one box per row.
# NOTE(review): presumably (xmin, ymin, xmax, ymax) pixel coordinates -- confirm
# against draw_bounding_boxes. The second box is degenerate (all zeros).
boxes = torch.tensor(
    [
        [0, 0, 20, 20],
        [0, 0, 0, 0],
        [10, 15, 30, 35],
        [23, 35, 93, 95],
    ],
    dtype=torch.float32,
)

# Three 5x5 float masks shared by the draw_segmentation_masks tests.
# Every row of every mask holds one repeated value, so only one value per row
# is listed below and it is tiled across the 5 columns. ``repeat`` allocates a
# new contiguous tensor (not a view), so the result is an ordinary (3, 5, 5)
# tensor exactly like a fully written-out literal.
masks = torch.tensor(
    [
        [-2.2799, 5.0914, -2.2799, -2.2799, -2.2799],  # mask 0: 5.0914 in row 1
        [5.0914, -2.2799, 5.0914, 5.0914, -1.4541],    # mask 1: 5.0914 in rows 0, 2, 3
        [-1.4541, -1.4541, -1.4541, -1.4541, 5.0914],  # mask 2: 5.0914 in row 4
    ],
    dtype=torch.float32,
).unsqueeze(-1).repeat(1, 1, 5)


class Tester(unittest.TestCase):
    """Tests for ``torchvision.utils``: make_grid, save_image and the draw_* helpers."""

    # ----------------------------------------------------------------- helpers

    def _assert_equal_to_asset(self, result, filename):
        """Assert that ``result`` equals the PNG asset ``assets/fakedata/<filename>``.

        ``result`` is a channels-first image tensor as returned by the ``draw_*``
        utilities exercised below. If the asset file does not exist yet, it is
        first generated from ``result``. On that run the comparison passes
        trivially; commit the generated file to make it a real regression check.
        """
        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", filename)
        if not os.path.exists(path):
            res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
            res.save(path)

        # Context manager so the asset's file handle is released promptly.
        with Image.open(path) as expected_img:
            expected = torch.as_tensor(np.array(expected_img)).permute(2, 0, 1)
        self.assertTrue(torch.equal(result, expected))

    def _assert_save_image_to_path(self, t, msg):
        """Save ``t`` to a temporary ``.png`` path and check that data was written."""
        with tempfile.NamedTemporaryFile(suffix='.png') as f:
            utils.save_image(t, f.name)
            self.assertTrue(os.path.exists(f.name), msg)
            # NamedTemporaryFile already created the (empty) file, so the
            # existence check alone can never fail; require actual content too.
            self.assertGreater(os.path.getsize(f.name), 0, msg)

    def _assert_save_image_to_file_object(self, t, msg):
        """Check that saving ``t`` to a file object yields the same pixels as saving to a path."""
        with tempfile.NamedTemporaryFile(suffix='.png') as f:
            utils.save_image(t, f.name)
            fp = BytesIO()
            utils.save_image(t, fp, format='png')
            fp.seek(0)  # rewind so the buffer is read from its first byte
            with Image.open(f.name) as img_orig, Image.open(fp) as img_bytes:
                self.assertTrue(torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)), msg)

    # --------------------------------------------------------------- make_grid

    def test_make_grid_not_inplace(self):
        """make_grid must leave its input tensor untouched for every normalization mode."""
        t = torch.rand(5, 3, 10, 10)
        t_clone = t.clone()

        for kwargs in (
            {'normalize': False},
            {'normalize': True, 'scale_each': False},
            {'normalize': True, 'scale_each': True},
        ):
            with self.subTest(**kwargs):
                utils.make_grid(t, **kwargs)
                self.assertTrue(torch.equal(t, t_clone), 'make_grid modified tensor in-place')

    def test_normalize_in_make_grid(self):
        """With normalize=True the grid's max should be 1 and its min 0 (to one decimal)."""
        t = torch.rand(5, 3, 10, 10) * 255
        norm_max = torch.tensor(1.0)
        norm_min = torch.tensor(0.0)

        grid = utils.make_grid(t, normalize=True)
        grid_max = torch.max(grid)
        grid_min = torch.min(grid)

        # Round to one decimal so tiny floating-point error does not break the
        # exact tensor comparison below.
        n_digits = 1
        scale = 10 ** n_digits
        rounded_grid_max = torch.round(grid_max * scale) / scale
        rounded_grid_min = torch.round(grid_min * scale) / scale

        self.assertTrue(torch.equal(norm_max, rounded_grid_max), 'Normalized max is not equal to 1')
        self.assertTrue(torch.equal(norm_min, rounded_grid_min), 'Normalized min is not equal to 0')

    # -------------------------------------------------------------- save_image
    # NOTE(review): the Windows skip is presumably because a NamedTemporaryFile
    # cannot be reopened by name there while still open -- confirm.

    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_save_image(self):
        """save_image writes a non-empty file for a batch of images."""
        self._assert_save_image_to_path(torch.rand(2, 3, 64, 64), 'The image is not present after save')

    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_save_image_single_pixel(self):
        """save_image writes a non-empty file for a single 1x1 image."""
        self._assert_save_image_to_path(torch.rand(1, 3, 1, 1), 'The pixel image is not present after save')

    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_save_image_file_object(self):
        """Saving a batch to a file object matches saving it to a path."""
        self._assert_save_image_to_file_object(torch.rand(2, 3, 64, 64), 'Image not stored in file object')

    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_save_image_single_pixel_file_object(self):
        """Saving a single 1x1 image to a file object matches saving it to a path."""
        self._assert_save_image_to_file_object(torch.rand(1, 3, 1, 1), 'Pixel Image not stored in file object')

    # ------------------------------------------------------ draw_bounding_boxes

    def test_draw_boxes(self):
        """Filled, labelled, coloured boxes match the stored golden image."""
        img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
        img_cp = img.clone()
        boxes_cp = boxes.clone()

        labels = ["a", "b", "c", "d"]
        colors = ["green", "#FF00FF", (0, 255, 0), "red"]
        result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors, fill=True)

        self._assert_equal_to_asset(result, "draw_boxes_util.png")

        # The inputs must not be modified in place.
        self.assertTrue(torch.equal(boxes, boxes_cp))
        self.assertTrue(torch.equal(img, img_cp))

    def test_draw_boxes_vanilla(self):
        """Unfilled boxes of width 7 with default colours match the stored golden image."""
        img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
        img_cp = img.clone()
        boxes_cp = boxes.clone()

        result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7)

        self._assert_equal_to_asset(result, "draw_boxes_vanilla.png")

        # The inputs must not be modified in place.
        self.assertTrue(torch.equal(boxes, boxes_cp))
        self.assertTrue(torch.equal(img, img_cp))

    def test_draw_invalid_boxes(self):
        """Non-tensor images raise TypeError; wrong dtype or rank raises ValueError."""
        img_tp = ((1, 1, 1), (1, 2, 3))
        img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
        img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)

        self.assertRaises(TypeError, utils.draw_bounding_boxes, img_tp, boxes)
        self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong1, boxes)
        self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong2, boxes)

    # -------------------------------------------------- draw_segmentation_masks

    def test_draw_segmentation_masks_colors(self):
        """Masks drawn with explicit colours match the stored golden image."""
        img = torch.full((3, 5, 5), 255, dtype=torch.uint8)
        img_cp = img.clone()
        masks_cp = masks.clone()

        colors = ["#FF00FF", (0, 255, 0), "red"]
        result = utils.draw_segmentation_masks(img, masks, colors=colors)

        self._assert_equal_to_asset(result, "draw_segm_masks_colors_util.png")

        # The inputs must not be modified in place.
        self.assertTrue(torch.equal(img, img_cp))
        self.assertTrue(torch.equal(masks, masks_cp))

    def test_draw_segmentation_masks_no_colors(self):
        """Masks drawn with the default palette (colors=None) match the stored golden image."""
        img = torch.full((3, 20, 20), 255, dtype=torch.uint8)
        img_cp = img.clone()
        masks_cp = masks.clone()

        result = utils.draw_segmentation_masks(img, masks, colors=None)

        self._assert_equal_to_asset(result, "draw_segm_masks_no_colors_util.png")

        # The inputs must not be modified in place.
        self.assertTrue(torch.equal(img, img_cp))
        self.assertTrue(torch.equal(masks, masks_cp))

    def test_draw_invalid_masks(self):
        """Non-tensor images raise TypeError; wrong dtype, rank or channel count raises ValueError."""
        img_tp = ((1, 1, 1), (1, 2, 3))
        img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
        img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
        img_wrong3 = torch.full((4, 5, 5), 255, dtype=torch.uint8)

        self.assertRaises(TypeError, utils.draw_segmentation_masks, img_tp, masks)
        self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong1, masks)
        self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong2, masks)
        self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong3, masks)


if __name__ == "__main__":
    # Allow running this file directly (e.g. ``python test_utils.py``).
    unittest.main()