import os
import sys
import tempfile
import torch
import torchvision.utils as utils
import unittest
from io import BytesIO
import torchvision.transforms.functional as F
from PIL import Image


class Tester(unittest.TestCase):
    """Tests for ``utils.make_grid`` and ``utils.save_image``."""

    # Match exact platform names: a substring test such as
    # ``'win' in sys.platform`` would also match 'darwin' (macOS) and silently
    # skip the save tests there.
    # NOTE(review): the Windows skip is presumably because these tests reopen a
    # NamedTemporaryFile by name while it is still open, which the tempfile
    # documentation says is not supported on Windows -- confirm.
    _IS_WINDOWS = sys.platform in ('win32', 'cygwin')

    def _assert_saved_to_path(self, t, msg):
        """Save ``t`` to a temporary PNG path and check that data was written.

        The temporary file already exists as soon as it is created, so a plain
        existence check could never fail; the file must be non-empty instead.
        """
        with tempfile.NamedTemporaryFile(suffix='.png') as f:
            utils.save_image(t, f.name)
            self.assertGreater(os.path.getsize(f.name), 0, msg)

    def _assert_file_object_matches_path(self, t, msg):
        """Check that saving ``t`` to a file object gives the same pixels as
        saving it to a path."""
        with tempfile.NamedTemporaryFile(suffix='.png') as f:
            utils.save_image(t, f.name)
            fp = BytesIO()
            utils.save_image(t, fp, format='png')
            # Rewind so the image is decoded from the start of the buffer.
            fp.seek(0)
            # Opening both images in a ``with`` releases their file handles
            # once the pixel data has been compared.
            with Image.open(f.name) as img_orig, Image.open(fp) as img_bytes:
                self.assertTrue(torch.equal(F.to_tensor(img_orig), F.to_tensor(img_bytes)), msg)

    def test_make_grid_not_inplace(self):
        # make_grid must leave its input untouched for every normalization mode.
        t = torch.rand(5, 3, 10, 10)
        t_clone = t.clone()

        configs = (
            {'normalize': False},
            {'normalize': True, 'scale_each': False},
            {'normalize': True, 'scale_each': True},
        )
        for kwargs in configs:
            # subTest reports which configuration mutated the input.
            with self.subTest(**kwargs):
                utils.make_grid(t, **kwargs)
                self.assertTrue(torch.equal(t, t_clone), 'make_grid modified tensor in-place')

    def test_normalize_in_make_grid(self):
        # Scale the input well outside [0, 1] so that normalization has an effect.
        t = torch.rand(5, 3, 10, 10) * 255
        norm_max = torch.tensor(1.0)
        norm_min = torch.tensor(0.0)

        grid = utils.make_grid(t, normalize=True)
        grid_max = torch.max(grid)
        grid_min = torch.min(grid)

        # Rounding the result to one decimal for comparison
        n_digits = 1
        rounded_grid_max = torch.round(grid_max * 10 ** n_digits) / (10 ** n_digits)
        rounded_grid_min = torch.round(grid_min * 10 ** n_digits) / (10 ** n_digits)

        self.assertTrue(torch.equal(norm_max, rounded_grid_max), 'Normalized max is not equal to 1')
        self.assertTrue(torch.equal(norm_min, rounded_grid_min), 'Normalized min is not equal to 0')

    @unittest.skipIf(_IS_WINDOWS, 'temporarily disabled on Windows')
    def test_save_image(self):
        self._assert_saved_to_path(torch.rand(2, 3, 64, 64),
                                   'The image is not present after save')

    @unittest.skipIf(_IS_WINDOWS, 'temporarily disabled on Windows')
    def test_save_image_single_pixel(self):
        self._assert_saved_to_path(torch.rand(1, 3, 1, 1),
                                   'The pixel image is not present after save')

    @unittest.skipIf(_IS_WINDOWS, 'temporarily disabled on Windows')
    def test_save_image_file_object(self):
        self._assert_file_object_matches_path(torch.rand(2, 3, 64, 64),
                                              'Image not stored in file object')

    @unittest.skipIf(_IS_WINDOWS, 'temporarily disabled on Windows')
    def test_save_image_single_pixel_file_object(self):
        self._assert_file_object_matches_path(torch.rand(1, 3, 1, 1),
                                              'Pixel Image not stored in file object')

if __name__ == '__main__':
    # Run this module's test cases when the file is executed directly.
    unittest.main()