# test_transforms_tensor.py
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL.Image import NEAREST, BILINEAR, BICUBIC
import numpy as np

import unittest

from common_utils import TransformsTester
class Tester(TransformsTester):
    def setUp(self):
        self.device = "cpu"

19
    def _test_functional_op(self, func, fn_kwargs):
20
21
        if fn_kwargs is None:
            fn_kwargs = {}
22
23

        f = getattr(F, func)
24
        tensor, pil_img = self._create_data(height=10, width=10, device=self.device)
25
26
        transformed_tensor = f(tensor, **fn_kwargs)
        transformed_pil_img = f(pil_img, **fn_kwargs)
27
28
        self.compareTensorToPIL(transformed_tensor, transformed_pil_img)

29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
    def _test_transform_vs_scripted(self, transform, s_transform, tensor):
        torch.manual_seed(12)
        out1 = transform(tensor)
        torch.manual_seed(12)
        out2 = s_transform(tensor)
        self.assertTrue(out1.equal(out2))

    def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_tensors):
        torch.manual_seed(12)
        transformed_batch = transform(batch_tensors)

        for i in range(len(batch_tensors)):
            img_tensor = batch_tensors[i, ...]
            torch.manual_seed(12)
            transformed_img = transform(img_tensor)
            self.assertTrue(transformed_img.equal(transformed_batch[i, ...]))

        torch.manual_seed(12)
        s_transformed_batch = s_transform(batch_tensors)
        self.assertTrue(transformed_batch.equal(s_transformed_batch))

50
    def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs):
51
52
        if meth_kwargs is None:
            meth_kwargs = {}
vfdev's avatar
vfdev committed
53
54
55
56
57

        # test for class interface
        f = getattr(T, method)(**meth_kwargs)
        scripted_fn = torch.jit.script(f)

58
        tensor, pil_img = self._create_data(26, 34, device=self.device)
vfdev's avatar
vfdev committed
59
60
61
62
63
        # set seed to reproduce the same transformation for tensor and PIL image
        torch.manual_seed(12)
        transformed_tensor = f(tensor)
        torch.manual_seed(12)
        transformed_pil_img = f(pil_img)
64
65
66
67
        if test_exact_match:
            self.compareTensorToPIL(transformed_tensor, transformed_pil_img, **match_kwargs)
        else:
            self.approxEqualTensorToPIL(transformed_tensor.float(), transformed_pil_img, **match_kwargs)
68

vfdev's avatar
vfdev committed
69
70
        torch.manual_seed(12)
        transformed_tensor_script = scripted_fn(tensor)
71
        self.assertTrue(transformed_tensor.equal(transformed_tensor_script))
72

73
74
75
        batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
        self._test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)

76
77
78
    def _test_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
        self._test_functional_op(func, fn_kwargs)
        self._test_class_op(method, meth_kwargs)
79
80

    def test_random_horizontal_flip(self):
81
        self._test_op('hflip', 'RandomHorizontalFlip')
82
83

    def test_random_vertical_flip(self):
84
        self._test_op('vflip', 'RandomVerticalFlip')
85

vfdev's avatar
vfdev committed
86
87
88
    def test_color_jitter(self):

        tol = 1.0 + 1e-10
89
        for f in [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]]:
vfdev's avatar
vfdev committed
90
91
92
93
94
            meth_kwargs = {"brightness": f}
            self._test_class_op(
                "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
            )

95
        for f in [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]]:
vfdev's avatar
vfdev committed
96
97
98
99
100
            meth_kwargs = {"contrast": f}
            self._test_class_op(
                "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
            )

101
        for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]:
vfdev's avatar
vfdev committed
102
103
104
105
            meth_kwargs = {"saturation": f}
            self._test_class_op(
                "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
            )
106

107
108
109
110
111
112
113
114
115
116
117
118
        for f in [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]]:
            meth_kwargs = {"hue": f}
            self._test_class_op(
                "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=0.1, agg_method="mean"
            )

        # All 4 parameters together
        meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
        self._test_class_op(
            "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=0.1, agg_method="mean"
        )

119
120
121
    def test_pad(self):

        # Test functional.pad (PIL and Tensor) with padding as single int
122
        self._test_functional_op(
123
124
125
126
            "pad", fn_kwargs={"padding": 2, "fill": 0, "padding_mode": "constant"}
        )
        # Test functional.pad and transforms.Pad with padding as [int, ]
        fn_kwargs = meth_kwargs = {"padding": [2, ], "fill": 0, "padding_mode": "constant"}
127
        self._test_op(
128
129
130
131
            "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        # Test functional.pad and transforms.Pad with padding as list
        fn_kwargs = meth_kwargs = {"padding": [4, 4], "fill": 0, "padding_mode": "constant"}
132
        self._test_op(
133
134
135
136
            "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        # Test functional.pad and transforms.Pad with padding as tuple
        fn_kwargs = meth_kwargs = {"padding": (2, 2, 2, 2), "fill": 127, "padding_mode": "constant"}
137
        self._test_op(
138
139
140
            "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )

vfdev's avatar
vfdev committed
141
142
143
144
    def test_crop(self):
        fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
        # Test transforms.RandomCrop with size and padding as tuple
        meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, }
145
        self._test_op(
vfdev's avatar
vfdev committed
146
147
148
            'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )

vfdev's avatar
vfdev committed
149
150
151
152
153
154
155
156
157
158
159
160
161
        sizes = [5, [5, ], [6, 6]]
        padding_configs = [
            {"padding_mode": "constant", "fill": 0},
            {"padding_mode": "constant", "fill": 10},
            {"padding_mode": "constant", "fill": 20},
            {"padding_mode": "edge"},
            {"padding_mode": "reflect"},
        ]

        for size in sizes:
            for padding_config in padding_configs:
                config = dict(padding_config)
                config["size"] = size
162
                self._test_class_op("RandomCrop", config)
vfdev's avatar
vfdev committed
163
164
165
166

    def test_center_crop(self):
        fn_kwargs = {"output_size": (4, 5)}
        meth_kwargs = {"size": (4, 5), }
167
        self._test_op(
vfdev's avatar
vfdev committed
168
169
170
171
            "center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        fn_kwargs = {"output_size": (5,)}
        meth_kwargs = {"size": (5, )}
172
        self._test_op(
vfdev's avatar
vfdev committed
173
174
            "center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
175
        tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8, device=self.device)
vfdev's avatar
vfdev committed
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
        # Test torchscript of transforms.CenterCrop with size as int
        f = T.CenterCrop(size=5)
        scripted_fn = torch.jit.script(f)
        scripted_fn(tensor)

        # Test torchscript of transforms.CenterCrop with size as [int, ]
        f = T.CenterCrop(size=[5, ])
        scripted_fn = torch.jit.script(f)
        scripted_fn(tensor)

        # Test torchscript of transforms.CenterCrop with size as tuple
        f = T.CenterCrop(size=(6, 6))
        scripted_fn = torch.jit.script(f)
        scripted_fn(tensor)

191
    def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
vfdev's avatar
vfdev committed
192
193
194
195
        if fn_kwargs is None:
            fn_kwargs = {}
        if meth_kwargs is None:
            meth_kwargs = {}
196
197
198
199

        fn = getattr(F, func)
        scripted_fn = torch.jit.script(fn)

200
        tensor, pil_img = self._create_data(height=20, width=20, device=self.device)
201
202
        transformed_t_list = fn(tensor, **fn_kwargs)
        transformed_p_list = fn(pil_img, **fn_kwargs)
vfdev's avatar
vfdev committed
203
204
205
206
207
208
209
210
211
212
213
214
215
        self.assertEqual(len(transformed_t_list), len(transformed_p_list))
        self.assertEqual(len(transformed_t_list), out_length)
        for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list):
            self.compareTensorToPIL(transformed_tensor, transformed_pil_img)

        transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs)
        self.assertEqual(len(transformed_t_list), len(transformed_t_list_script))
        self.assertEqual(len(transformed_t_list_script), out_length)
        for transformed_tensor, transformed_tensor_script in zip(transformed_t_list, transformed_t_list_script):
            self.assertTrue(transformed_tensor.equal(transformed_tensor_script),
                            msg="{} vs {}".format(transformed_tensor, transformed_tensor_script))

        # test for class interface
216
217
        fn = getattr(T, method)(**meth_kwargs)
        scripted_fn = torch.jit.script(fn)
vfdev's avatar
vfdev committed
218
219
220
        output = scripted_fn(tensor)
        self.assertEqual(len(output), len(transformed_t_list_script))

221
222
223
224
225
226
227
228
229
230
231
232
233
        # test on batch of tensors
        batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
        torch.manual_seed(12)
        transformed_batch_list = fn(batch_tensors)

        for i in range(len(batch_tensors)):
            img_tensor = batch_tensors[i, ...]
            torch.manual_seed(12)
            transformed_img_list = fn(img_tensor)
            for transformed_img, transformed_batch in zip(transformed_img_list, transformed_batch_list):
                self.assertTrue(transformed_img.equal(transformed_batch[i, ...]),
                                msg="{} vs {}".format(transformed_img, transformed_batch[i, ...]))

vfdev's avatar
vfdev committed
234
235
    def test_five_crop(self):
        fn_kwargs = meth_kwargs = {"size": (5,)}
236
        self._test_op_list_output(
vfdev's avatar
vfdev committed
237
238
239
            "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        fn_kwargs = meth_kwargs = {"size": [5, ]}
240
        self._test_op_list_output(
vfdev's avatar
vfdev committed
241
242
243
            "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        fn_kwargs = meth_kwargs = {"size": (4, 5)}
244
        self._test_op_list_output(
vfdev's avatar
vfdev committed
245
246
247
            "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        fn_kwargs = meth_kwargs = {"size": [4, 5]}
248
        self._test_op_list_output(
vfdev's avatar
vfdev committed
249
250
251
252
253
            "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )

    def test_ten_crop(self):
        fn_kwargs = meth_kwargs = {"size": (5,)}
254
        self._test_op_list_output(
vfdev's avatar
vfdev committed
255
256
257
            "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        fn_kwargs = meth_kwargs = {"size": [5, ]}
258
        self._test_op_list_output(
vfdev's avatar
vfdev committed
259
260
261
            "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        fn_kwargs = meth_kwargs = {"size": (4, 5)}
262
        self._test_op_list_output(
vfdev's avatar
vfdev committed
263
264
265
            "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        fn_kwargs = meth_kwargs = {"size": [4, 5]}
266
        self._test_op_list_output(
vfdev's avatar
vfdev committed
267
268
269
            "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )

vfdev's avatar
vfdev committed
270
    def test_resize(self):
271
        tensor, _ = self._create_data(height=34, width=36, device=self.device)
272
        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
vfdev's avatar
vfdev committed
273
274
275
276
277
278
        script_fn = torch.jit.script(F.resize)

        for dt in [None, torch.float32, torch.float64]:
            if dt is not None:
                # This is a trivial cast to float of uint8 data to test all cases
                tensor = tensor.to(dt)
279
            for size in [32, 34, [32, ], [32, 32], (32, 32), [34, 35]]:
vfdev's avatar
vfdev committed
280
281
282
283
284
285
286
287
288
289
290
291
292
                for interpolation in [BILINEAR, BICUBIC, NEAREST]:

                    resized_tensor = F.resize(tensor, size=size, interpolation=interpolation)

                    if isinstance(size, int):
                        script_size = [size, ]
                    else:
                        script_size = size

                    s_resized_tensor = script_fn(tensor, size=script_size, interpolation=interpolation)
                    self.assertTrue(s_resized_tensor.equal(resized_tensor))

                    transform = T.Resize(size=script_size, interpolation=interpolation)
293
294
295
                    s_transform = torch.jit.script(transform)
                    self._test_transform_vs_scripted(transform, s_transform, tensor)
                    self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
vfdev's avatar
vfdev committed
296

297
    def test_resized_crop(self):
298
        tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
299
        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
300

301
302
        for scale in [(0.7, 1.2), [0.7, 1.2]]:
            for ratio in [(0.75, 1.333), [0.75, 1.333]]:
303
                for size in [(32, ), [44, ], [32, ], [32, 32], (32, 32), [44, 55]]:
304
305
306
307
308
                    for interpolation in [NEAREST, BILINEAR, BICUBIC]:
                        transform = T.RandomResizedCrop(
                            size=size, scale=scale, ratio=ratio, interpolation=interpolation
                        )
                        s_transform = torch.jit.script(transform)
309
310
                        self._test_transform_vs_scripted(transform, s_transform, tensor)
                        self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
311
312

    def test_random_affine(self):
313
        tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
314
        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
315
316
317
318
319
320
321
322
323
324
325
326

        for shear in [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]]:
            for scale in [(0.7, 1.2), [0.7, 1.2]]:
                for translate in [(0.1, 0.2), [0.2, 0.1]]:
                    for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
                        for interpolation in [NEAREST, BILINEAR]:
                            transform = T.RandomAffine(
                                degrees=degrees, translate=translate,
                                scale=scale, shear=shear, resample=interpolation
                            )
                            s_transform = torch.jit.script(transform)

327
328
                            self._test_transform_vs_scripted(transform, s_transform, tensor)
                            self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
329

330
    def test_random_rotate(self):
331
        tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
332
        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
333
334
335
336
337
338
339
340
341
342

        for center in [(0, 0), [10, 10], None, (56, 44)]:
            for expand in [True, False]:
                for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
                    for interpolation in [NEAREST, BILINEAR]:
                        transform = T.RandomRotation(
                            degrees=degrees, resample=interpolation, expand=expand, center=center
                        )
                        s_transform = torch.jit.script(transform)

343
344
                        self._test_transform_vs_scripted(transform, s_transform, tensor)
                        self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
345

346
    def test_random_perspective(self):
347
        tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
348
        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
349
350
351
352
353
354
355
356
357

        for distortion_scale in np.linspace(0.1, 1.0, num=20):
            for interpolation in [NEAREST, BILINEAR]:
                transform = T.RandomPerspective(
                    distortion_scale=distortion_scale,
                    interpolation=interpolation
                )
                s_transform = torch.jit.script(transform)

358
359
                self._test_transform_vs_scripted(transform, s_transform, tensor)
                self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
360

361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
    def test_to_grayscale(self):

        meth_kwargs = {"num_output_channels": 1}
        tol = 1.0 + 1e-10
        self._test_class_op(
            "Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
        )

        meth_kwargs = {"num_output_channels": 3}
        self._test_class_op(
            "Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
        )

        meth_kwargs = {}
        self._test_class_op(
            "RandomGrayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
        )

379

380
381
382
383
384
385
386
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):

    def setUp(self):
        self.device = "cuda"


387
388
if __name__ == '__main__':
    unittest.main()