test_utils.py 2.05 KB
Newer Older
1
import os
Francisco Massa's avatar
Francisco Massa committed
2
import sys
3
import tempfile
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import torchvision.utils as utils
import unittest


class Tester(unittest.TestCase):
    """Tests for ``utils.make_grid`` and ``utils.save_image``."""

    def test_make_grid_not_inplace(self):
        """``make_grid`` must leave its input tensor unchanged in every normalize mode."""
        t = torch.rand(5, 3, 10, 10)
        t_clone = t.clone()

        # The same tensor is reused for every call: if any call mutated it,
        # the comparison against the pristine clone fails from then on.
        option_sets = (
            {'normalize': False},
            {'normalize': True, 'scale_each': False},
            {'normalize': True, 'scale_each': True},
        )
        for kwargs in option_sets:
            utils.make_grid(t, **kwargs)
            # self.assertTrue (unlike a bare ``assert``) still runs under ``python -O``.
            self.assertTrue(torch.equal(t, t_clone), 'make_grid modified tensor in-place')

    def test_normalize_in_make_grid(self):
        """With ``normalize=True`` the grid values must span (approximately) [0, 1]."""
        t = torch.rand(5, 3, 10, 10) * 255
        norm_max = torch.tensor(1.0)
        norm_min = torch.tensor(0.0)

        grid = utils.make_grid(t, normalize=True)
        grid_max = torch.max(grid)
        grid_min = torch.min(grid)

        # Round to one decimal before comparing, so the check tolerates small
        # deviations of the normalized extrema from exactly 0 and 1.
        n_digits = 1
        scale = 10 ** n_digits
        rounded_grid_max = torch.round(grid_max * scale) / scale
        rounded_grid_min = torch.round(grid_min * scale) / scale

        self.assertTrue(torch.equal(norm_max, rounded_grid_max), 'Normalized max is not equal to 1')
        self.assertTrue(torch.equal(norm_min, rounded_grid_min), 'Normalized min is not equal to 0')

    # NOTE: the skip condition compares against the exact Windows platform
    # identifiers. A substring test such as ``'win' in sys.platform`` also
    # matches 'darwin' and would silently skip these tests on macOS.
    # NOTE(review): presumably disabled on Windows because a NamedTemporaryFile
    # cannot be reopened by name there while it is still open -- confirm.
    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_save_image(self):
        """``save_image`` must write a non-empty file for a batch of images."""
        with tempfile.NamedTemporaryFile(suffix='.png') as f:
            t = torch.rand(2, 3, 64, 64)
            utils.save_image(t, f.name)
            self.assertTrue(os.path.exists(f.name), 'The image is not present after save')
            # NamedTemporaryFile already created an (empty) file, so existence
            # alone proves nothing; require that data was actually written.
            self.assertGreater(os.path.getsize(f.name), 0, 'The saved image file is empty')

    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_save_image_single_pixel(self):
        """``save_image`` must also handle the degenerate 1x1 image case."""
        with tempfile.NamedTemporaryFile(suffix='.png') as f:
            t = torch.rand(1, 3, 1, 1)
            utils.save_image(t, f.name)
            self.assertTrue(os.path.exists(f.name), 'The pixel image is not present after save')
            # Same reasoning as in test_save_image: the temp file exists before
            # the save, so check that it is no longer empty.
            self.assertGreater(os.path.getsize(f.name), 0, 'The saved pixel image file is empty')

55
56
57

# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()