test_utils.py 3.2 KB
Newer Older
1
import os
Francisco Massa's avatar
Francisco Massa committed
2
import sys
3
import tempfile
4
5
6
import torch
import torchvision.utils as utils
import unittest
7
8
9
from io import BytesIO
import torchvision.transforms.functional as F
from PIL import Image
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26


class Tester(unittest.TestCase):
    """Tests for ``torchvision.utils.make_grid`` and ``torchvision.utils.save_image``."""

    def test_make_grid_not_inplace(self):
        """``make_grid`` must leave its input tensor untouched for every normalization mode."""
        t = torch.rand(5, 3, 10, 10)
        t_clone = t.clone()

        for normalize, scale_each in ((False, False), (True, False), (True, True)):
            utils.make_grid(t, normalize=normalize, scale_each=scale_each)
            self.assertTrue(torch.equal(t, t_clone), 'make_grid modified tensor in-place')

    def test_normalize_in_make_grid(self):
        """With ``normalize=True`` the grid values are rescaled to span [0, 1]."""
        t = torch.rand(5, 3, 10, 10) * 255

        # padding=0: otherwise the zero-valued border make_grid inserts between
        # images would make the minimum 0 regardless of normalization, and the
        # min check below would test nothing.
        grid = utils.make_grid(t, normalize=True, padding=0)

        self.assertAlmostEqual(torch.max(grid).item(), 1.0, places=4,
                               msg='Normalized max is not equal to 1')
        self.assertAlmostEqual(torch.min(grid).item(), 0.0, places=4,
                               msg='Normalized min is not equal to 0')

    def _assert_png_written(self, path, msg):
        """Assert that ``path`` is a non-empty file that PIL decodes as a PNG.

        NamedTemporaryFile creates the file before save_image runs, so an
        existence check alone would pass even if nothing had been written.
        """
        self.assertTrue(os.path.exists(path), msg)
        self.assertGreater(os.path.getsize(path), 0, msg)
        with Image.open(path) as img:
            self.assertEqual(img.format, 'PNG', msg)

    def _assert_file_object_matches_path(self, t, msg):
        """Save ``t`` to a temporary .png path and to an in-memory buffer.

        Asserts that both outputs decode to identical tensors.
        """
        with tempfile.NamedTemporaryFile(suffix='.png') as f:
            utils.save_image(t, f.name)
            # Decode while the image is open; PIL loads pixel data lazily.
            with Image.open(f.name) as img_orig:
                expected = F.to_tensor(img_orig)

        fp = BytesIO()
        utils.save_image(t, fp, format='png')
        fp.seek(0)  # decode from the start of what save_image wrote
        with Image.open(fp) as img_bytes:
            actual = F.to_tensor(img_bytes)

        self.assertTrue(torch.equal(expected, actual), msg)

    # NamedTemporaryFile cannot be reopened by name while still open on Windows.
    # Compare with 'win32' exactly: "'win' in sys.platform" also matched macOS
    # ('darwin') and Cygwin.
    @unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
    def test_save_image(self):
        with tempfile.NamedTemporaryFile(suffix='.png') as f:
            t = torch.rand(2, 3, 64, 64)
            utils.save_image(t, f.name)
            self._assert_png_written(f.name, 'The image is not present after save')

    @unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
    def test_save_image_single_pixel(self):
        with tempfile.NamedTemporaryFile(suffix='.png') as f:
            t = torch.rand(1, 3, 1, 1)
            utils.save_image(t, f.name)
            self._assert_png_written(f.name, 'The pixel image is not present after save')

    @unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
    def test_save_image_file_object(self):
        t = torch.rand(2, 3, 64, 64)
        self._assert_file_object_matches_path(t, 'Image not stored in file object')

    @unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
    def test_save_image_single_pixel_file_object(self):
        t = torch.rand(1, 3, 1, 1)
        self._assert_file_object_matches_path(t, 'Pixel Image not stored in file object')

80
81
82

if __name__ == "__main__":
    # Allow running this test module directly, e.g. ``python test_utils.py``.
    unittest.main()