# test_utils.py: unit tests for torchvision.utils
1
import numpy as np
2
import os
# (source-viewer blame annotation: Francisco Massa committed)
3
import sys
4
import tempfile
5
6
7
import torch
import torchvision.utils as utils
import unittest
8
9
10
from io import BytesIO
import torchvision.transforms.functional as F
from PIL import Image
11

12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# Scores for 3 masks on a 5x5 grid (shape (3, 5, 5), float32). Every row of every
# mask is constant across its width, so each mask is described by one value per row.
# Taking the largest value along dim 0, the winning mask per row (top to bottom) is
# 1, 0, 1, 1, 2.
masks = torch.tensor(
    [
        [[value] * 5 for value in row_values]
        for row_values in (
            (-2.2799, 5.0914, -2.2799, -2.2799, -2.2799),
            (5.0914, -2.2799, 5.0914, 5.0914, -1.4541),
            (-1.4541, -1.4541, -1.4541, -1.4541, 5.0914),
        )
    ],
    dtype=torch.float,
)

36
37
38
39
40
41
42
43

class Tester(unittest.TestCase):
    """Tests for ``torchvision.utils``: ``make_grid``, ``save_image`` and the drawing helpers."""

    @staticmethod
    def _golden_path(filename):
        """Return the path of golden image ``filename`` under ``assets/fakedata`` next to this file."""
        return os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", filename)

    def _assert_matches_golden(self, result, filename):
        """Assert that the CHW image tensor ``result`` equals the golden PNG ``filename``.

        If the golden file does not exist yet, it is generated from ``result``. The first
        run therefore only bootstraps the fixture and compares ``result`` with itself.
        """
        path = self._golden_path(filename)
        if not os.path.exists(path):
            # CHW -> HWC for PIL.
            res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
            res.save(path)

        # Read the pixels inside the context manager so the file handle is released.
        with Image.open(path) as golden:
            expected = torch.as_tensor(np.array(golden)).permute(2, 0, 1)
        self.assertTrue(torch.equal(result, expected))

    def _assert_file_object_matches_path(self, t, msg):
        """Save ``t`` to a named PNG and to an in-memory buffer; assert both decode to the same tensor."""
        with tempfile.NamedTemporaryFile(suffix='.png') as f:
            utils.save_image(t, f.name)
            with Image.open(f.name) as img_orig:
                from_path = F.to_tensor(img_orig)

        fp = BytesIO()
        utils.save_image(t, fp, format='png')
        with Image.open(fp) as img_bytes:
            from_buffer = F.to_tensor(img_bytes)

        self.assertTrue(torch.equal(from_path, from_buffer), msg)

    def test_make_grid_not_inplace(self):
        t = torch.rand(5, 3, 10, 10)
        t_clone = t.clone()

        # Every normalisation mode must leave the input tensor untouched.
        for kwargs in ({'normalize': False},
                       {'normalize': True, 'scale_each': False},
                       {'normalize': True, 'scale_each': True}):
            with self.subTest(**kwargs):
                utils.make_grid(t, **kwargs)
                self.assertTrue(torch.equal(t, t_clone), 'make_grid modified tensor in-place')

    def test_normalize_in_make_grid(self):
        t = torch.rand(5, 3, 10, 10) * 255
        norm_max = torch.tensor(1.0)
        norm_min = torch.tensor(0.0)

        grid = utils.make_grid(t, normalize=True)

        # Round to one decimal so tiny floating-point error from the rescaling does
        # not make the exact tensor comparison fail.
        n_digits = 1
        scale = 10 ** n_digits
        rounded_grid_max = torch.round(torch.max(grid) * scale) / scale
        rounded_grid_min = torch.round(torch.min(grid) * scale) / scale

        self.assertTrue(torch.equal(norm_max, rounded_grid_max), 'Normalized max is not equal to 1')
        self.assertTrue(torch.equal(norm_min, rounded_grid_min), 'Normalized min is not equal to 0')

    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_save_image(self):
        with tempfile.NamedTemporaryFile(suffix='.png') as f:
            t = torch.rand(2, 3, 64, 64)
            utils.save_image(t, f.name)
            # NamedTemporaryFile has already created the (empty) file, so checking that
            # it exists proves nothing. Require that save_image actually wrote data.
            self.assertGreater(os.path.getsize(f.name), 0, 'The image is not present after save')

    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_save_image_single_pixel(self):
        with tempfile.NamedTemporaryFile(suffix='.png') as f:
            t = torch.rand(1, 3, 1, 1)
            utils.save_image(t, f.name)
            # See test_save_image: the temp file exists before the save, so check its size.
            self.assertGreater(os.path.getsize(f.name), 0, 'The pixel image is not present after save')

    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_save_image_file_object(self):
        t = torch.rand(2, 3, 64, 64)
        self._assert_file_object_matches_path(t, 'Image not stored in file object')

    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_save_image_single_pixel_file_object(self):
        t = torch.rand(1, 3, 1, 1)
        self._assert_file_object_matches_path(t, 'Pixel Image not stored in file object')

    def test_draw_boxes(self):
        img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
        # Includes a degenerate all-zero box.
        boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
                              [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
        labels = ["a", "b", "c", "d"]
        colors = ["green", "#FF00FF", (0, 255, 0), "red"]
        result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors, fill=True)
        self._assert_matches_golden(result, "draw_boxes_util.png")

    def test_draw_segmentation_masks_colors(self):
        img = torch.full((3, 5, 5), 255, dtype=torch.uint8)
        colors = ["#FF00FF", (0, 255, 0), "red"]
        result = utils.draw_segmentation_masks(img, masks, colors=colors)
        self._assert_matches_golden(result, "draw_segm_masks_colors_util.png")

    def test_draw_segmentation_masks_no_colors(self):
        # NOTE(review): the image is 20x20 while the module-level `masks` are 5x5. This
        # presumably relies on draw_segmentation_masks resizing masks to the image size;
        # confirm against the implementation.
        img = torch.full((3, 20, 20), 255, dtype=torch.uint8)
        result = utils.draw_segmentation_masks(img, masks, colors=None)
        self._assert_matches_golden(result, "draw_segm_masks_no_colors_util.png")

152
153
154

if __name__ == "__main__":
    # Run every test case in this module when it is executed directly as a script.
    unittest.main()