import os
import sys
import tempfile
import unittest
from io import BytesIO

import torch
import torchvision.transforms.functional as F
import torchvision.utils as utils
from PIL import Image


# Shared skip marker for the save_image tests. They reopen a still-open
# NamedTemporaryFile by name, which the tempfile documentation says is not
# possible on Windows. sys.platform is compared exactly rather than with a
# substring test, because ``'win' in sys.platform`` would also match 'darwin'
# and silently skip these tests on macOS.
_skip_on_windows = unittest.skipIf(sys.platform in ('win32', 'cygwin'),
                                   'temporarily disabled on Windows')


class Tester(unittest.TestCase):
    """Tests for torchvision.utils.make_grid and torchvision.utils.save_image.

    unittest assertion methods are used instead of bare ``assert`` so the
    checks are not stripped when Python runs with ``-O``.
    """

    def test_make_grid_not_inplace(self):
        # make_grid must leave its input tensor untouched in every normalization mode.
        t = torch.rand(5, 3, 10, 10)
        t_clone = t.clone()

        for kwargs in ({'normalize': False},
                       {'normalize': True, 'scale_each': False},
                       {'normalize': True, 'scale_each': True}):
            with self.subTest(**kwargs):
                utils.make_grid(t, **kwargs)
                self.assertTrue(torch.equal(t, t_clone), 'make_grid modified tensor in-place')

    def test_normalize_in_make_grid(self):
        # Input values lie in [0, 255); with normalize=True the grid's extremes
        # are expected to land on 1 and 0.
        t = torch.rand(5, 3, 10, 10) * 255
        grid = utils.make_grid(t, normalize=True)

        # Compare to one decimal place; floating-point rescaling need not hit 0/1 exactly.
        self.assertAlmostEqual(grid.max().item(), 1.0, places=1,
                               msg='Normalized max is not equal to 1')
        self.assertAlmostEqual(grid.min().item(), 0.0, places=1,
                               msg='Normalized min is not equal to 0')

    def _assert_saved_to_path(self, t, msg):
        """Save ``t`` to a temporary PNG path and check that data was written.

        NamedTemporaryFile creates the file up front, so an existence check
        alone could never fail; the size check proves save_image wrote bytes.
        """
        with tempfile.NamedTemporaryFile(suffix='.png') as f:
            utils.save_image(t, f.name)
            self.assertTrue(os.path.exists(f.name) and os.path.getsize(f.name) > 0, msg)

    def _assert_file_object_matches_path(self, t, msg):
        """Check that saving ``t`` into a BytesIO yields the same pixels as saving to a path."""
        with tempfile.NamedTemporaryFile(suffix='.png') as f:
            utils.save_image(t, f.name)
            # Close the image handle before the temporary file is deleted.
            with Image.open(f.name) as img_orig:
                expected = F.to_tensor(img_orig)

        fp = BytesIO()
        utils.save_image(t, fp, format='png')
        with Image.open(fp) as img_bytes:
            actual = F.to_tensor(img_bytes)

        self.assertTrue(torch.equal(expected, actual), msg)

    @_skip_on_windows
    def test_save_image(self):
        self._assert_saved_to_path(torch.rand(2, 3, 64, 64),
                                   'The image is not present after save')

    @_skip_on_windows
    def test_save_image_single_pixel(self):
        self._assert_saved_to_path(torch.rand(1, 3, 1, 1),
                                   'The pixel image is not present after save')

    @_skip_on_windows
    def test_save_image_file_object(self):
        self._assert_file_object_matches_path(torch.rand(2, 3, 64, 64),
                                              'Image not stored in file object')

    @_skip_on_windows
    def test_save_image_single_pixel_file_object(self):
        self._assert_file_object_matches_path(torch.rand(1, 3, 1, 1),
                                              'Pixel Image not stored in file object')

if __name__ == '__main__':
    # Allow running this module directly; unittest.main() discovers and runs
    # the TestCase classes defined in this module.
    unittest.main()