# test_functional_tensor.py: tests for torchvision.transforms.functional_tensor.
from __future__ import division

import random
import unittest

import numpy as np
import torch
from torch import Tensor
from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple

import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import torchvision.transforms.functional_tensor as F_t


def _pil_to_uint8_tensor(pil_img):
    """Return the pixels of ``pil_img`` as a ``(C, H, W)`` uint8 tensor.

    The pixels are read with ``np.array`` so the reference holds the exact
    8-bit values. This avoids the ``(ToTensor()(img) * 255).to(torch.uint8)``
    float round trip, whose final cast truncates: a ``v / 255 * 255`` result
    landing a hair below ``v`` would become ``v - 1``.

    ``np.array`` gives ``(H, W)`` for single-band images and ``(H, W, C)`` for
    multi-band ones. A channel axis is added in the first case before moving
    channels to the front.
    """
    pixels = np.array(pil_img)
    if pixels.ndim == 2:
        pixels = pixels[:, :, np.newaxis]
    return torch.from_numpy(pixels).permute(2, 0, 1).contiguous()


class Tester(unittest.TestCase):
    """Tests for ``torchvision.transforms.functional_tensor`` (``F_t``).

    Where a PIL-based counterpart exists in ``torchvision.transforms.functional``
    (``F``), the tensor result is compared against it. Each test also checks
    that the op leaves its input tensor unmodified. Each test also runs the
    ``torch.jit.script``-compiled version of the op.
    """

    def _assert_sequences_equal(self, actual, expected):
        """Assert two sequences of tensors have equal length and pairwise-equal items."""
        self.assertEqual(len(actual), len(expected))
        for idx, (actual_img, expected_img) in enumerate(zip(actual, expected)):
            self.assertTrue(torch.equal(actual_img, expected_img), "item {} differs".format(idx))

    def _assert_crops_match_pil(self, tensor_crops, pil_crops, pil_order):
        """Assert ``tensor_crops[i]`` equals ``pil_crops[pil_order[i]]`` for every ``i``.

        ``pil_order`` maps each position in the ``F_t`` result to the position
        of the same crop in the ``F`` (PIL) result.
        """
        self.assertEqual(len(tensor_crops), len(pil_order))
        self.assertEqual(len(pil_crops), len(pil_order))
        for tensor_idx, pil_idx in enumerate(pil_order):
            expected = _pil_to_uint8_tensor(pil_crops[pil_idx])
            self.assertTrue(torch.equal(tensor_crops[tensor_idx], expected),
                            "tensor crop {} does not match PIL crop {}".format(tensor_idx, pil_idx))

    def test_vflip(self):
        script_vflip = torch.jit.script(F_t.vflip)
        img_tensor = torch.randn(3, 16, 16)
        img_tensor_clone = img_tensor.clone()

        vflipped_img = F_t.vflip(img_tensor)
        vflipped_img_again = F_t.vflip(vflipped_img)
        self.assertEqual(vflipped_img.shape, img_tensor.shape)
        # A vertical flip reverses the row (height) axis of a (C, H, W) image.
        # Checking only the double flip below would also accept a no-op.
        self.assertTrue(torch.equal(vflipped_img, img_tensor.flip(-2)))
        self.assertTrue(torch.equal(img_tensor, vflipped_img_again))
        # The op must not modify its input in place.
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))

        # scriptable function test
        vflipped_img_script = script_vflip(img_tensor)
        self.assertTrue(torch.equal(vflipped_img, vflipped_img_script))

    def test_hflip(self):
        script_hflip = torch.jit.script(F_t.hflip)
        img_tensor = torch.randn(3, 16, 16)
        img_tensor_clone = img_tensor.clone()

        hflipped_img = F_t.hflip(img_tensor)
        hflipped_img_again = F_t.hflip(hflipped_img)
        self.assertEqual(hflipped_img.shape, img_tensor.shape)
        # A horizontal flip reverses the column (width) axis of a (C, H, W) image.
        # Checking only the double flip below would also accept a no-op.
        self.assertTrue(torch.equal(hflipped_img, img_tensor.flip(-1)))
        self.assertTrue(torch.equal(img_tensor, hflipped_img_again))
        # The op must not modify its input in place.
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))

        # scriptable function test
        hflipped_img_script = script_hflip(img_tensor)
        self.assertTrue(torch.equal(hflipped_img, hflipped_img_script))

    def test_crop(self):
        script_crop = torch.jit.script(F_t.crop)
        # torch.randint's upper bound is exclusive, so 256 covers the full uint8 range.
        img_tensor = torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)
        img_tensor_clone = img_tensor.clone()

        # random.randint is inclusive on both ends: the crop box is at least
        # 1x1 and always fits inside the 16x16 image.
        top = random.randint(0, 15)
        left = random.randint(0, 15)
        height = random.randint(1, 16 - top)
        width = random.randint(1, 16 - left)
        img_cropped = F_t.crop(img_tensor, top, left, height, width)
        img_PIL = transforms.ToPILImage()(img_tensor)
        img_PIL_cropped = F.crop(img_PIL, top, left, height, width)
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
        self.assertTrue(torch.equal(img_cropped, _pil_to_uint8_tensor(img_PIL_cropped)),
                        "functional_tensor crop not working")

        # scriptable function test
        cropped_img_script = script_crop(img_tensor, top, left, height, width)
        self.assertTrue(torch.equal(img_cropped, cropped_img_script))

    def test_adjustments(self):
        script_adjust_brightness = torch.jit.script(F_t.adjust_brightness)
        script_adjust_contrast = torch.jit.script(F_t.adjust_contrast)
        script_adjust_saturation = torch.jit.script(F_t.adjust_saturation)

        # (PIL-based op, tensor op, scripted tensor op) triples.
        fns = ((F.adjust_brightness, F_t.adjust_brightness, script_adjust_brightness),
               (F.adjust_contrast, F_t.adjust_contrast, script_adjust_contrast),
               (F.adjust_saturation, F_t.adjust_saturation, script_adjust_saturation))

        for _ in range(20):
            channels = 3
            height, width = torch.randint(1, 50, (2,)).tolist()
            shape = (channels, height, width)

            # Exercise both float images in [0, 1) and uint8 images.
            if torch.randint(0, 2, (1,)) == 0:
                img = torch.rand(*shape, dtype=torch.float)
            else:
                img = torch.randint(0, 256, shape, dtype=torch.uint8)

            # Use a Python float rather than a 1-element tensor. Then neither the
            # PIL ops nor the scripted ops rely on implicit tensor-to-scalar conversion.
            factor = 3 * torch.rand(1).item()
            img_clone = img.clone()
            for f, ft, sft in fns:
                ft_img = ft(img, factor)
                sft_img = sft(img, factor)
                if not img.dtype.is_floating_point:
                    # Rescale uint8 results into [0, 1] to compare with ToTensor output.
                    ft_img = ft_img.to(torch.float) / 255
                    sft_img = sft_img.to(torch.float) / 255

                img_pil = transforms.ToPILImage()(img)
                f_img_pil = f(img_pil, factor)
                f_img = transforms.ToTensor()(f_img_pil)

                # F uses uint8 and F_t uses float, so there is a small
                # difference in values caused by (at most 5) truncations.
                max_diff = (ft_img - f_img).abs().max().item()
                max_diff_scripted = (sft_img - f_img).abs().max().item()
                self.assertLess(max_diff, 5 / 255 + 1e-5)
                self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)
                self.assertTrue(torch.equal(img, img_clone))

    def test_rgb_to_grayscale(self):
        script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale)
        img_tensor = torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)
        img_tensor_clone = img_tensor.clone()

        # Compare in int64 so the subtraction below cannot wrap around as uint8 would.
        grayscale_tensor = F_t.rgb_to_grayscale(img_tensor).to(torch.int64)
        grayscale_pil_img = torch.tensor(np.array(F.to_grayscale(F.to_pil_image(img_tensor)))).to(torch.int64)
        max_diff = (grayscale_tensor - grayscale_pil_img).abs().max()
        # Tolerate an off-by-one difference, presumably from differing rounding.
        self.assertLess(max_diff, 1.0001)
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))

        # scriptable function test
        grayscale_script = script_rgb_to_grayscale(img_tensor).to(torch.int64)
        self.assertTrue(torch.equal(grayscale_script, grayscale_tensor))

    def test_center_crop(self):
        script_center_crop = torch.jit.script(F_t.center_crop)
        img_tensor = torch.randint(0, 256, (1, 32, 32), dtype=torch.uint8)
        img_tensor_clone = img_tensor.clone()

        cropped_tensor = F_t.center_crop(img_tensor, [10, 10])
        cropped_pil_image = F.center_crop(transforms.ToPILImage()(img_tensor), [10, 10])
        self.assertTrue(torch.equal(cropped_tensor, _pil_to_uint8_tensor(cropped_pil_image)))
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))

        # scriptable function test
        cropped_script = script_center_crop(img_tensor, [10, 10])
        self.assertTrue(torch.equal(cropped_script, cropped_tensor))

    def test_five_crop(self):
        script_five_crop = torch.jit.script(F_t.five_crop)
        img_tensor = torch.randint(0, 256, (1, 32, 32), dtype=torch.uint8)
        img_tensor_clone = img_tensor.clone()

        cropped_tensor = F_t.five_crop(img_tensor, [10, 10])
        cropped_pil_image = F.five_crop(transforms.ToPILImage()(img_tensor), [10, 10])
        # F_t returns the crops at positions 1 and 2 in the opposite order from F.
        self._assert_crops_match_pil(cropped_tensor, cropped_pil_image, (0, 2, 1, 3, 4))
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))

        # scriptable function test
        cropped_script = script_five_crop(img_tensor, [10, 10])
        self._assert_sequences_equal(cropped_script, cropped_tensor)

    def test_ten_crop(self):
        script_ten_crop = torch.jit.script(F_t.ten_crop)
        img_tensor = torch.randint(0, 256, (1, 32, 32), dtype=torch.uint8)
        img_tensor_clone = img_tensor.clone()

        cropped_tensor = F_t.ten_crop(img_tensor, [10, 10])
        cropped_pil_image = F.ten_crop(transforms.ToPILImage()(img_tensor), [10, 10])
        # As in five_crop, F_t swaps positions 1/2 relative to F.
        # The flipped half likewise swaps positions 6/7.
        self._assert_crops_match_pil(cropped_tensor, cropped_pil_image, (0, 2, 1, 3, 4, 5, 7, 6, 8, 9))
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))

        # scriptable function test
        cropped_script = script_ten_crop(img_tensor, [10, 10])
        self._assert_sequences_equal(cropped_script, cropped_tensor)

if __name__ == '__main__':
    # Run every TestCase defined in this module when executed as a script;
    # '__main__' is also unittest.main's default module.
    unittest.main(module='__main__')