import random
import unittest

import numpy as np
import torch
from torch import Tensor
from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple

import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import torchvision.transforms.functional_tensor as F_t


class Tester(unittest.TestCase):

    def test_vflip(self):
15
        script_vflip = torch.jit.script(F_t.vflip)
16
        img_tensor = torch.randn(3, 16, 16)
17
        img_tensor_clone = img_tensor.clone()
18
19
20
21
        vflipped_img = F_t.vflip(img_tensor)
        vflipped_img_again = F_t.vflip(vflipped_img)
        self.assertEqual(vflipped_img.shape, img_tensor.shape)
        self.assertTrue(torch.equal(img_tensor, vflipped_img_again))
22
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
23
24
25
        # scriptable function test
        vflipped_img_script = script_vflip(img_tensor)
        self.assertTrue(torch.equal(vflipped_img, vflipped_img_script))
26
27

    def test_hflip(self):
28
        script_hflip = torch.jit.script(F_t.hflip)
29
        img_tensor = torch.randn(3, 16, 16)
30
        img_tensor_clone = img_tensor.clone()
31
32
33
34
        hflipped_img = F_t.hflip(img_tensor)
        hflipped_img_again = F_t.hflip(hflipped_img)
        self.assertEqual(hflipped_img.shape, img_tensor.shape)
        self.assertTrue(torch.equal(img_tensor, hflipped_img_again))
35
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
36
37
38
        # scriptable function test
        hflipped_img_script = script_hflip(img_tensor)
        self.assertTrue(torch.equal(hflipped_img, hflipped_img_script))
39

ekka's avatar
ekka committed
40
    def test_crop(self):
41
        script_crop = torch.jit.script(F_t.crop)
ekka's avatar
ekka committed
42
        img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
43
        img_tensor_clone = img_tensor.clone()
ekka's avatar
ekka committed
44
45
46
47
48
49
50
51
        top = random.randint(0, 15)
        left = random.randint(0, 15)
        height = random.randint(1, 16 - top)
        width = random.randint(1, 16 - left)
        img_cropped = F_t.crop(img_tensor, top, left, height, width)
        img_PIL = transforms.ToPILImage()(img_tensor)
        img_PIL_cropped = F.crop(img_PIL, top, left, height, width)
        img_cropped_GT = transforms.ToTensor()(img_PIL_cropped)
52
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
ekka's avatar
ekka committed
53
54
        self.assertTrue(torch.equal(img_cropped, (img_cropped_GT * 255).to(torch.uint8)),
                        "functional_tensor crop not working")
55
56
57
        # scriptable function test
        cropped_img_script = script_crop(img_tensor, top, left, height, width)
        self.assertTrue(torch.equal(img_cropped, cropped_img_script))
ekka's avatar
ekka committed
58

59
    def test_adjustments(self):
60
61
62
63
64
65
66
        script_adjust_brightness = torch.jit.script(F_t.adjust_brightness)
        script_adjust_contrast = torch.jit.script(F_t.adjust_contrast)
        script_adjust_saturation = torch.jit.script(F_t.adjust_saturation)

        fns = ((F.adjust_brightness, F_t.adjust_brightness, script_adjust_brightness),
               (F.adjust_contrast, F_t.adjust_contrast, script_adjust_contrast),
               (F.adjust_saturation, F_t.adjust_saturation, script_adjust_saturation))
67
68
69
70
71
72
73
74
75
76
77
78

        for _ in range(20):
            channels = 3
            dims = torch.randint(1, 50, (2,))
            shape = (channels, dims[0], dims[1])

            if torch.randint(0, 2, (1,)) == 0:
                img = torch.rand(*shape, dtype=torch.float)
            else:
                img = torch.randint(0, 256, shape, dtype=torch.uint8)

            factor = 3 * torch.rand(1)
79
            img_clone = img.clone()
80
            for f, ft, sft in fns:
81
82

                ft_img = ft(img, factor)
83
                sft_img = sft(img, factor)
84
85
                if not img.dtype.is_floating_point:
                    ft_img = ft_img.to(torch.float) / 255
86
                    sft_img = sft_img.to(torch.float) / 255
87
88
89
90
91
92
93
94

                img_pil = transforms.ToPILImage()(img)
                f_img_pil = f(img_pil, factor)
                f_img = transforms.ToTensor()(f_img_pil)

                # F uses uint8 and F_t uses float, so there is a small
                # difference in values caused by (at most 5) truncations.
                max_diff = (ft_img - f_img).abs().max()
95
                max_diff_scripted = (sft_img - f_img).abs().max()
96
                self.assertLess(max_diff, 5 / 255 + 1e-5)
97
                self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)
98
                self.assertTrue(torch.equal(img, img_clone))
99

100
    def test_rgb_to_grayscale(self):
101
        script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale)
102
        img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
103
        img_tensor_clone = img_tensor.clone()
104
105
106
107
        grayscale_tensor = F_t.rgb_to_grayscale(img_tensor).to(int)
        grayscale_pil_img = torch.tensor(np.array(F.to_grayscale(F.to_pil_image(img_tensor)))).to(int)
        max_diff = (grayscale_tensor - grayscale_pil_img).abs().max()
        self.assertLess(max_diff, 1.0001)
108
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
109
110
111
        # scriptable function test
        grayscale_script = script_rgb_to_grayscale(img_tensor).to(int)
        self.assertTrue(torch.equal(grayscale_script, grayscale_tensor))
112

113
    def test_center_crop(self):
114
        script_center_crop = torch.jit.script(F_t.center_crop)
115
        img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
116
        img_tensor_clone = img_tensor.clone()
117
118
119
120
        cropped_tensor = F_t.center_crop(img_tensor, [10, 10])
        cropped_pil_image = F.center_crop(transforms.ToPILImage()(img_tensor), [10, 10])
        cropped_pil_tensor = (transforms.ToTensor()(cropped_pil_image) * 255).to(torch.uint8)
        self.assertTrue(torch.equal(cropped_tensor, cropped_pil_tensor))
121
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
122
123
124
        # scriptable function test
        cropped_script = script_center_crop(img_tensor, [10, 10])
        self.assertTrue(torch.equal(cropped_script, cropped_tensor))
125
126

    def test_five_crop(self):
127
        script_five_crop = torch.jit.script(F_t.five_crop)
128
        img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
129
        img_tensor_clone = img_tensor.clone()
130
131
132
133
134
135
136
137
138
139
140
141
        cropped_tensor = F_t.five_crop(img_tensor, [10, 10])
        cropped_pil_image = F.five_crop(transforms.ToPILImage()(img_tensor), [10, 10])
        self.assertTrue(torch.equal(cropped_tensor[0],
                                    (transforms.ToTensor()(cropped_pil_image[0]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[1],
                                    (transforms.ToTensor()(cropped_pil_image[2]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[2],
                                    (transforms.ToTensor()(cropped_pil_image[1]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[3],
                                    (transforms.ToTensor()(cropped_pil_image[3]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[4],
                                    (transforms.ToTensor()(cropped_pil_image[4]) * 255).to(torch.uint8)))
142
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
143
144
145
146
        # scriptable function test
        cropped_script = script_five_crop(img_tensor, [10, 10])
        for cropped_script_img, cropped_tensor_img in zip(cropped_script, cropped_tensor):
            self.assertTrue(torch.equal(cropped_script_img, cropped_tensor_img))
147
148

    def test_ten_crop(self):
149
        script_ten_crop = torch.jit.script(F_t.ten_crop)
150
        img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
151
        img_tensor_clone = img_tensor.clone()
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
        cropped_tensor = F_t.ten_crop(img_tensor, [10, 10])
        cropped_pil_image = F.ten_crop(transforms.ToPILImage()(img_tensor), [10, 10])
        self.assertTrue(torch.equal(cropped_tensor[0],
                                    (transforms.ToTensor()(cropped_pil_image[0]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[1],
                                    (transforms.ToTensor()(cropped_pil_image[2]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[2],
                                    (transforms.ToTensor()(cropped_pil_image[1]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[3],
                                    (transforms.ToTensor()(cropped_pil_image[3]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[4],
                                    (transforms.ToTensor()(cropped_pil_image[4]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[5],
                                    (transforms.ToTensor()(cropped_pil_image[5]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[6],
                                    (transforms.ToTensor()(cropped_pil_image[7]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[7],
                                    (transforms.ToTensor()(cropped_pil_image[6]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[8],
                                    (transforms.ToTensor()(cropped_pil_image[8]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[9],
                                    (transforms.ToTensor()(cropped_pil_image[9]) * 255).to(torch.uint8)))
174
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
175
176
177
178
        # scriptable function test
        cropped_script = script_ten_crop(img_tensor, [10, 10])
        for cropped_script_img, cropped_tensor_img in zip(cropped_script, cropped_tensor):
            self.assertTrue(torch.equal(cropped_script_img, cropped_tensor_img))
179

180
181
182

if __name__ == '__main__':
    # Executed as a script: hand control to unittest's command-line runner,
    # which loads the TestCase classes defined in this module.
    unittest.main(module='__main__')