import torch
from torch import Tensor
import torchvision.transforms as transforms
import torchvision.transforms.functional_tensor as F_t
import torchvision.transforms.functional as F
import numpy as np
import unittest
import random
import colorsys
from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple


class Tester(unittest.TestCase):

    def test_vflip(self):
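        # Flipping twice should restore the original image, and the input must not be modified in place.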
        script_vflip = torch.jit.script(F_t.vflip)
        img_tensor = torch.randn(3, 16, 16)
        img_tensor_clone = img_tensor.clone()
        vflipped_img = F_t.vflip(img_tensor)
        vflipped_img_again = F_t.vflip(vflipped_img)
        self.assertEqual(vflipped_img.shape, img_tensor.shape)
        self.assertTrue(torch.equal(img_tensor, vflipped_img_again))
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
        # scriptable function test
        vflipped_img_script = script_vflip(img_tensor)
        self.assertTrue(torch.equal(vflipped_img, vflipped_img_script))

    def test_hflip(self):
        script_hflip = torch.jit.script(F_t.hflip)
        img_tensor = torch.randn(3, 16, 16)
        img_tensor_clone = img_tensor.clone()
        hflipped_img = F_t.hflip(img_tensor)
        hflipped_img_again = F_t.hflip(hflipped_img)
        self.assertEqual(hflipped_img.shape, img_tensor.shape)
        self.assertTrue(torch.equal(img_tensor, hflipped_img_again))
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
        # scriptable function test
        hflipped_img_script = script_hflip(img_tensor)
        self.assertTrue(torch.equal(hflipped_img, hflipped_img_script))

    def test_crop(self):
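        # Crop a random window with the tensor op and check it matches the PIL-based F.crop reference.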
        script_crop = torch.jit.script(F_t.crop)
        img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
        img_tensor_clone = img_tensor.clone()
        top = random.randint(0, 15)
        left = random.randint(0, 15)
        height = random.randint(1, 16 - top)
        width = random.randint(1, 16 - left)
        img_cropped = F_t.crop(img_tensor, top, left, height, width)
        img_PIL = transforms.ToPILImage()(img_tensor)
        img_PIL_cropped = F.crop(img_PIL, top, left, height, width)
        img_cropped_GT = transforms.ToTensor()(img_PIL_cropped)
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
        self.assertTrue(torch.equal(img_cropped, (img_cropped_GT * 255).to(torch.uint8)),
                        "functional_tensor crop not working")
        # scriptable function test
        cropped_img_script = script_crop(img_tensor, top, left, height, width)
        self.assertTrue(torch.equal(img_cropped, cropped_img_script))

    def test_hsv2rgb(self):
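        # Compare the tensor HSV -> RGB conversion against a per-pixel colorsys reference on random images.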
        shape = (3, 100, 150)
        for _ in range(20):
            img = torch.rand(*shape, dtype=torch.float)
            ft_img = F_t._hsv2rgb(img).permute(1, 2, 0).flatten(0, 1)

            h, s, v = img.unbind(0)
            h = h.flatten().numpy()
            s = s.flatten().numpy()
            v = v.flatten().numpy()

            rgb = []
            for h1, s1, v1 in zip(h, s, v):
                rgb.append(colorsys.hsv_to_rgb(h1, s1, v1))

            colorsys_img = torch.tensor(rgb, dtype=torch.float32)
            max_diff = (ft_img - colorsys_img).abs().max()
            self.assertLess(max_diff, 1e-5)

    def test_rgb2hsv(self):
        shape = (3, 150, 100)
        for _ in range(20):
            img = torch.rand(*shape, dtype=torch.float)
            ft_hsv_img = F_t._rgb2hsv(img).permute(1, 2, 0).flatten(0, 1)

            r, g, b = img.unbind(0)
            r = r.flatten().numpy()
            g = g.flatten().numpy()
            b = b.flatten().numpy()

            hsv = []
            for r1, g1, b1 in zip(r, g, b):
                hsv.append(colorsys.rgb_to_hsv(r1, g1, b1))

            colorsys_img = torch.tensor(hsv, dtype=torch.float32)

            max_diff = (colorsys_img - ft_hsv_img).abs().max()
            self.assertLess(max_diff, 1e-5)

    def test_adjustments(self):
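        # Check brightness/contrast/saturation adjustments (eager and scripted) against the PIL
        # reference implementations, on both float and uint8 images.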
        script_adjust_brightness = torch.jit.script(F_t.adjust_brightness)
        script_adjust_contrast = torch.jit.script(F_t.adjust_contrast)
        script_adjust_saturation = torch.jit.script(F_t.adjust_saturation)

        fns = ((F.adjust_brightness, F_t.adjust_brightness, script_adjust_brightness),
               (F.adjust_contrast, F_t.adjust_contrast, script_adjust_contrast),
               (F.adjust_saturation, F_t.adjust_saturation, script_adjust_saturation))

        for _ in range(20):
            channels = 3
            dims = torch.randint(1, 50, (2,))
            shape = (channels, dims[0], dims[1])

            if torch.randint(0, 2, (1,)) == 0:
                img = torch.rand(*shape, dtype=torch.float)
            else:
                img = torch.randint(0, 256, shape, dtype=torch.uint8)

            factor = 3 * torch.rand(1)
            img_clone = img.clone()
            for f, ft, sft in fns:

                ft_img = ft(img, factor)
                sft_img = sft(img, factor)
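                # Normalize uint8 results to [0, 1] floats so they can be compared with the PIL reference.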
                if not img.dtype.is_floating_point:
                    ft_img = ft_img.to(torch.float) / 255
                    sft_img = sft_img.to(torch.float) / 255

                img_pil = transforms.ToPILImage()(img)
                f_img_pil = f(img_pil, factor)
                f_img = transforms.ToTensor()(f_img_pil)

                # F uses uint8 and F_t uses float, so there is a small
                # difference in values caused by (at most 5) truncations.
                max_diff = (ft_img - f_img).abs().max()
                max_diff_scripted = (sft_img - f_img).abs().max()
                self.assertLess(max_diff, 5 / 255 + 1e-5)
                self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)
                self.assertTrue(torch.equal(img, img_clone))

            # test for class interface
            f = transforms.ColorJitter(brightness=factor.item())
            scripted_fn = torch.jit.script(f)
            scripted_fn(img)

            f = transforms.ColorJitter(contrast=factor.item())
            scripted_fn = torch.jit.script(f)
            scripted_fn(img)

            f = transforms.ColorJitter(saturation=factor.item())
            scripted_fn = torch.jit.script(f)
            scripted_fn(img)

        f = transforms.ColorJitter(brightness=1)
        scripted_fn = torch.jit.script(f)
        scripted_fn(img)

    def test_rgb_to_grayscale(self):
        script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale)
        img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
        img_tensor_clone = img_tensor.clone()
        grayscale_tensor = F_t.rgb_to_grayscale(img_tensor).to(int)
        grayscale_pil_img = torch.tensor(np.array(F.to_grayscale(F.to_pil_image(img_tensor)))).to(int)
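        # The two conversion paths may round differently, so allow a difference of up to one intensity level.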
        max_diff = (grayscale_tensor - grayscale_pil_img).abs().max()
        self.assertLess(max_diff, 1.0001)
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
        # scriptable function test
        grayscale_script = script_rgb_to_grayscale(img_tensor).to(int)
        self.assertTrue(torch.equal(grayscale_script, grayscale_tensor))

    def test_center_crop(self):
        script_center_crop = torch.jit.script(F_t.center_crop)
        img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
        img_tensor_clone = img_tensor.clone()
        cropped_tensor = F_t.center_crop(img_tensor, [10, 10])
        cropped_pil_image = F.center_crop(transforms.ToPILImage()(img_tensor), [10, 10])
        cropped_pil_tensor = (transforms.ToTensor()(cropped_pil_image) * 255).to(torch.uint8)
        self.assertTrue(torch.equal(cropped_tensor, cropped_pil_tensor))
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
        # scriptable function test
        cropped_script = script_center_crop(img_tensor, [10, 10])
        self.assertTrue(torch.equal(cropped_script, cropped_tensor))

    def test_five_crop(self):
        script_five_crop = torch.jit.script(F_t.five_crop)
        img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
        img_tensor_clone = img_tensor.clone()
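        # The tensor and PIL versions return the corner crops in a different order,
        # hence the swapped indices in the comparisons below.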
        cropped_tensor = F_t.five_crop(img_tensor, [10, 10])
        cropped_pil_image = F.five_crop(transforms.ToPILImage()(img_tensor), [10, 10])
        self.assertTrue(torch.equal(cropped_tensor[0],
                                    (transforms.ToTensor()(cropped_pil_image[0]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[1],
                                    (transforms.ToTensor()(cropped_pil_image[2]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[2],
                                    (transforms.ToTensor()(cropped_pil_image[1]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[3],
                                    (transforms.ToTensor()(cropped_pil_image[3]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[4],
                                    (transforms.ToTensor()(cropped_pil_image[4]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
        # scriptable function test
        cropped_script = script_five_crop(img_tensor, [10, 10])
        for cropped_script_img, cropped_tensor_img in zip(cropped_script, cropped_tensor):
            self.assertTrue(torch.equal(cropped_script_img, cropped_tensor_img))

    def test_ten_crop(self):
        script_ten_crop = torch.jit.script(F_t.ten_crop)
        img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
        img_tensor_clone = img_tensor.clone()
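        # As in test_five_crop, some crops come back in a different order than from the PIL version.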
        cropped_tensor = F_t.ten_crop(img_tensor, [10, 10])
        cropped_pil_image = F.ten_crop(transforms.ToPILImage()(img_tensor), [10, 10])
        self.assertTrue(torch.equal(cropped_tensor[0],
                                    (transforms.ToTensor()(cropped_pil_image[0]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[1],
                                    (transforms.ToTensor()(cropped_pil_image[2]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[2],
                                    (transforms.ToTensor()(cropped_pil_image[1]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[3],
                                    (transforms.ToTensor()(cropped_pil_image[3]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[4],
                                    (transforms.ToTensor()(cropped_pil_image[4]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[5],
                                    (transforms.ToTensor()(cropped_pil_image[5]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[6],
                                    (transforms.ToTensor()(cropped_pil_image[7]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[7],
                                    (transforms.ToTensor()(cropped_pil_image[6]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[8],
                                    (transforms.ToTensor()(cropped_pil_image[8]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[9],
                                    (transforms.ToTensor()(cropped_pil_image[9]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
        # scriptable function test
        cropped_script = script_ten_crop(img_tensor, [10, 10])
        for cropped_script_img, cropped_tensor_img in zip(cropped_script, cropped_tensor):
            self.assertTrue(torch.equal(cropped_script_img, cropped_tensor_img))


if __name__ == '__main__':
    unittest.main()