from __future__ import division
import os
import mock
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from torch._utils_internal import get_file_path_2
import unittest
import math
import random
import numpy as np
from PIL import Image
try:
    import accimage
except ImportError:
    accimage = None

try:
    from scipy import stats
except ImportError:
    stats = None

GRACE_HOPPER = get_file_path_2(
    os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
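# get_file_path_2 is a torch-internal helper that joins its path components;
# it resolves the bundled test image relative to this file.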


class Tester(unittest.TestCase):

    def test_crop(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        oheight = random.randint(5, (height - 2) // 2) * 2
        owidth = random.randint(5, (width - 2) // 2) * 2

        img = torch.ones(3, height, width)
        oh1 = (height - oheight) // 2
        ow1 = (width - owidth) // 2
        imgnarrow = img[:, oh1:oh1 + oheight, ow1:ow1 + owidth]
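        # imgnarrow is a view onto the central (oheight, owidth) region, so
        # zeroing it below means an exact center crop of that size must sum to 0.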
        imgnarrow.fill_(0)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.CenterCrop((oheight, owidth)),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.sum(), 0,
                         "height: {} width: {} oheight: {} owidth: {}".format(height, width, oheight, owidth))
        oheight += 1
        owidth += 1
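        # Growing the crop by one pixel in each dimension now pulls in ones
        # from outside the zeroed center region, so the sum becomes positive.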
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.CenterCrop((oheight, owidth)),
            transforms.ToTensor(),
        ])(img)
        sum1 = result.sum()
        self.assertGreater(sum1, 1,
                           "height: {} width: {} oheight: {} owidth: {}".format(height, width, oheight, owidth))
        oheight += 1
        owidth += 1
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.CenterCrop((oheight, owidth)),
            transforms.ToTensor(),
        ])(img)
        sum2 = result.sum()
        self.assertGreater(sum2, 0,
                           "height: {} width: {} oheight: {} owidth: {}".format(height, width, oheight, owidth))
        self.assertGreater(sum2, sum1,
                           "height: {} width: {} oheight: {} owidth: {}".format(height, width, oheight, owidth))

    def test_five_crop(self):
        to_pil_image = transforms.ToPILImage()
        h = random.randint(5, 25)
        w = random.randint(5, 25)
        for single_dim in [True, False]:
            crop_h = random.randint(1, h)
            crop_w = random.randint(1, w)
            if single_dim:
                crop_h = min(crop_h, crop_w)
                crop_w = crop_h
                transform = transforms.FiveCrop(crop_h)
            else:
                transform = transforms.FiveCrop((crop_h, crop_w))

            img = torch.FloatTensor(3, h, w).uniform_()
            results = transform(to_pil_image(img))

            self.assertEqual(len(results), 5)
            for crop in results:
                self.assertEqual(crop.size, (crop_w, crop_h))

            to_pil_image = transforms.ToPILImage()
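            # Build the reference crops in the order FiveCrop returns them:
            # top-left, top-right, bottom-left, bottom-right, center.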
            tl = to_pil_image(img[:, 0:crop_h, 0:crop_w])
            tr = to_pil_image(img[:, 0:crop_h, w - crop_w:])
            bl = to_pil_image(img[:, h - crop_h:, 0:crop_w])
            br = to_pil_image(img[:, h - crop_h:, w - crop_w:])
            center = transforms.CenterCrop((crop_h, crop_w))(to_pil_image(img))
            expected_output = (tl, tr, bl, br, center)
            self.assertEqual(results, expected_output)

    def test_ten_crop(self):
        to_pil_image = transforms.ToPILImage()
        h = random.randint(5, 25)
        w = random.randint(5, 25)
        for should_vflip in [True, False]:
            for single_dim in [True, False]:
                crop_h = random.randint(1, h)
                crop_w = random.randint(1, w)
                if single_dim:
                    crop_h = min(crop_h, crop_w)
                    crop_w = crop_h
                    transform = transforms.TenCrop(crop_h,
                                                   vertical_flip=should_vflip)
                    five_crop = transforms.FiveCrop(crop_h)
                else:
                    transform = transforms.TenCrop((crop_h, crop_w),
                                                   vertical_flip=should_vflip)
                    five_crop = transforms.FiveCrop((crop_h, crop_w))

                img = to_pil_image(torch.FloatTensor(3, h, w).uniform_())
                results = transform(img)
                expected_output = five_crop(img)
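                # TenCrop should equal the five crops of the image followed by
                # the five crops of its flipped copy (vertical flip if
                # requested, horizontal otherwise).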

                # Checking if FiveCrop and TenCrop can be printed as string
                transform.__repr__()
                five_crop.__repr__()

                if should_vflip:
                    vflipped_img = img.transpose(Image.FLIP_TOP_BOTTOM)
                    expected_output += five_crop(vflipped_img)
                else:
                    hflipped_img = img.transpose(Image.FLIP_LEFT_RIGHT)
                    expected_output += five_crop(hflipped_img)

                self.assertEqual(len(results), 10)
                self.assertEqual(results, expected_output)

    def test_randomresized_params(self):
        height = random.randint(24, 32) * 2
        width = random.randint(24, 32) * 2
        img = torch.ones(3, height, width)
        to_pil_image = transforms.ToPILImage()
        img = to_pil_image(img)
        size = 100
        epsilon = 0.05
        min_scale = 0.25
        for _ in range(10):
            scale_min = max(round(random.random(), 2), min_scale)
            scale_range = (scale_min, scale_min + round(random.random(), 2))
            aspect_min = max(round(random.random(), 2), epsilon)
            aspect_ratio_range = (aspect_min, aspect_min + round(random.random(), 2))
            randresizecrop = transforms.RandomResizedCrop(size, scale_range, aspect_ratio_range)
            i, j, h, w = randresizecrop.get_params(img, scale_range, aspect_ratio_range)
            aspect_ratio_obtained = w / h
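            # get_params may fall back to a (square) center crop when it fails
            # to sample a valid crop, which yields an aspect ratio of exactly
            # 1.0 regardless of the requested range.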
            self.assertTrue((min(aspect_ratio_range) - epsilon <= aspect_ratio_obtained and
                             aspect_ratio_obtained <= max(aspect_ratio_range) + epsilon) or
                            aspect_ratio_obtained == 1.0)
            self.assertIsInstance(i, int)
            self.assertIsInstance(j, int)
            self.assertIsInstance(h, int)
            self.assertIsInstance(w, int)

    def test_randomperspective(self):
        for _ in range(10):
            height = random.randint(24, 32) * 2
            width = random.randint(24, 32) * 2
            img = torch.ones(3, height, width)
            to_pil_image = transforms.ToPILImage()
            img = to_pil_image(img)
            perp = transforms.RandomPerspective()
            startpoints, endpoints = perp.get_params(width, height, 0.5)
            tr_img = F.perspective(img, startpoints, endpoints)
            tr_img2 = F.to_tensor(F.perspective(tr_img, endpoints, startpoints))
            tr_img = F.to_tensor(tr_img)
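            # Warping back with the point sets swapped approximately inverts
            # the first distortion, so the round-tripped image should be
            # closer to the original; the 0.3 margin absorbs interpolation
            # error.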
            self.assertEqual(img.size[0], width)
            self.assertEqual(img.size[1], height)
            self.assertGreater(torch.nn.functional.mse_loss(tr_img, F.to_tensor(img)) + 0.3,
                               torch.nn.functional.mse_loss(tr_img2, F.to_tensor(img)))

    def test_resize(self):
        height = random.randint(24, 32) * 2
        width = random.randint(24, 32) * 2
        osize = random.randint(5, 12) * 2

        img = torch.ones(3, height, width)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(osize),
            transforms.ToTensor(),
        ])(img)
        self.assertIn(osize, result.size())
        if height < width:
            self.assertLessEqual(result.size(1), result.size(2))
        elif width < height:
            self.assertGreaterEqual(result.size(1), result.size(2))

        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize([osize, osize]),
            transforms.ToTensor(),
        ])(img)
        self.assertIn(osize, result.size())
        self.assertEqual(result.size(1), osize)
        self.assertEqual(result.size(2), osize)

        oheight = random.randint(5, 12) * 2
        owidth = random.randint(5, 12) * 2
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((oheight, owidth)),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), oheight)
        self.assertEqual(result.size(2), owidth)

        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize([oheight, owidth]),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), oheight)
        self.assertEqual(result.size(2), owidth)

    def test_random_crop(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        oheight = random.randint(5, (height - 2) // 2) * 2
        owidth = random.randint(5, (width - 2) // 2) * 2
        img = torch.ones(3, height, width)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop((oheight, owidth)),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), oheight)
        self.assertEqual(result.size(2), owidth)

        padding = random.randint(1, 20)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop((oheight, owidth), padding=padding),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), oheight)
        self.assertEqual(result.size(2), owidth)

        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop((height, width)),
            transforms.ToTensor()
        ])(img)
        self.assertEqual(result.size(1), height)
        self.assertEqual(result.size(2), width)
        self.assertTrue(np.allclose(img.numpy(), result.numpy()))

        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop((height + 1, width + 1), pad_if_needed=True),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), height + 1)
        self.assertEqual(result.size(2), width + 1)

    def test_pad(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        img = torch.ones(3, height, width)
        padding = random.randint(1, 20)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Pad(padding),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), height + 2 * padding)
        self.assertEqual(result.size(2), width + 2 * padding)

    def test_pad_with_tuple_of_pad_values(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        img = transforms.ToPILImage()(torch.ones(3, height, width))

        padding = tuple([random.randint(1, 20) for _ in range(2)])
        output = transforms.Pad(padding)(img)
        self.assertEqual(output.size, (width + padding[0] * 2, height + padding[1] * 2))

        padding = tuple([random.randint(1, 20) for _ in range(4)])
        output = transforms.Pad(padding)(img)
        self.assertEqual(output.size[0], width + padding[0] + padding[2])
        self.assertEqual(output.size[1], height + padding[1] + padding[3])

        # Checking if Padding can be printed as string
        transforms.Pad(padding).__repr__()

    def test_pad_with_non_constant_padding_modes(self):
        """Unit tests for edge, reflect, symmetric padding"""
        img = torch.zeros(3, 27, 27).byte()
        img[:, :, 0] = 1  # Constant value added to leftmost edge
        img = transforms.ToPILImage()(img)
        img = F.pad(img, 1, (200, 200, 200))

        # Pad 3 to all sides
        edge_padded_img = F.pad(img, 3, padding_mode='edge')
        # First 6 elements of leftmost edge in the middle of the image, values are in order:
        # edge_pad, edge_pad, edge_pad, constant_pad, constant value added to leftmost edge, 0
        edge_middle_slice = np.asarray(edge_padded_img).transpose(2, 0, 1)[0][17][:6]
        self.assertTrue(np.all(edge_middle_slice == np.asarray([200, 200, 200, 200, 1, 0])))
        self.assertEqual(transforms.ToTensor()(edge_padded_img).size(), (3, 35, 35))

        # Pad 3 to left/right, 2 to top/bottom
        reflect_padded_img = F.pad(img, (3, 2), padding_mode='reflect')
        # First 6 elements of leftmost edge in the middle of the image, values are in order:
        # reflect_pad, reflect_pad, reflect_pad, constant_pad, constant value added to leftmost edge, 0
        reflect_middle_slice = np.asarray(reflect_padded_img).transpose(2, 0, 1)[0][17][:6]
        self.assertTrue(np.all(reflect_middle_slice == np.asarray([0, 0, 1, 200, 1, 0])))
        self.assertEqual(transforms.ToTensor()(reflect_padded_img).size(), (3, 33, 35))

        # Pad 3 to left, 2 to top, 2 to right, 1 to bottom
        symmetric_padded_img = F.pad(img, (3, 2, 2, 1), padding_mode='symmetric')
        # First 6 elements of leftmost edge in the middle of the image, values are in order:
        # sym_pad, sym_pad, sym_pad, constant_pad, constant value added to leftmost edge, 0
        symmetric_middle_slice = np.asarray(symmetric_padded_img).transpose(2, 0, 1)[0][17][:6]
        self.assertTrue(np.all(symmetric_middle_slice == np.asarray([0, 1, 200, 200, 1, 0])))
        self.assertEqual(transforms.ToTensor()(symmetric_padded_img).size(), (3, 32, 34))

    def test_pad_raises_with_invalid_pad_sequence_len(self):
        with self.assertRaises(ValueError):
            transforms.Pad(())

        with self.assertRaises(ValueError):
            transforms.Pad((1, 2, 3))

        with self.assertRaises(ValueError):
            transforms.Pad((1, 2, 3, 4, 5))

    def test_lambda(self):
        trans = transforms.Lambda(lambda x: x.add(10))
        x = torch.randn(10)
        y = trans(x)
        self.assertTrue(y.equal(torch.add(x, 10)))

        trans = transforms.Lambda(lambda x: x.add_(10))
        x = torch.randn(10)
        y = trans(x)
        self.assertTrue(y.equal(x))

        # Checking if Lambda can be printed as string
        trans.__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_apply(self):
        random_state = random.getstate()
        random.seed(42)
        random_apply_transform = transforms.RandomApply(
            [
                transforms.RandomRotation((-45, 45)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
            ], p=0.75
        )
        img = transforms.ToPILImage()(torch.rand(3, 10, 10))
        num_samples = 250
        num_applies = 0
        for _ in range(num_samples):
            out = random_apply_transform(img)
            if out != img:
                num_applies += 1

        p_value = stats.binom_test(num_applies, num_samples, p=0.75)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        # Checking if RandomApply can be printed as string
        random_apply_transform.__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_choice(self):
        random_state = random.getstate()
        random.seed(42)
        random_choice_transform = transforms.RandomChoice(
            [
                transforms.Resize(15),
                transforms.Resize(20),
                transforms.CenterCrop(10)
            ]
        )
        img = transforms.ToPILImage()(torch.rand(3, 25, 25))
        num_samples = 250
        num_resize_15 = 0
        num_resize_20 = 0
        num_crop_10 = 0
        for _ in range(num_samples):
            out = random_choice_transform(img)
            if out.size == (15, 15):
                num_resize_15 += 1
            elif out.size == (20, 20):
                num_resize_20 += 1
            elif out.size == (10, 10):
                num_crop_10 += 1

        p_value = stats.binom_test(num_resize_15, num_samples, p=0.33333)
        self.assertGreater(p_value, 0.0001)
        p_value = stats.binom_test(num_resize_20, num_samples, p=0.33333)
        self.assertGreater(p_value, 0.0001)
        p_value = stats.binom_test(num_crop_10, num_samples, p=0.33333)
        self.assertGreater(p_value, 0.0001)

        random.setstate(random_state)
        # Checking if RandomChoice can be printed as string
        random_choice_transform.__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_order(self):
        random_state = random.getstate()
        random.seed(42)
        random_order_transform = transforms.RandomOrder(
            [
                transforms.Resize(20),
                transforms.CenterCrop(10)
            ]
        )
        img = transforms.ToPILImage()(torch.rand(3, 25, 25))
        num_samples = 250
        num_normal_order = 0
        resize_crop_out = transforms.CenterCrop(10)(transforms.Resize(20)(img))
        for _ in range(num_samples):
            out = random_order_transform(img)
            if out == resize_crop_out:
                num_normal_order += 1

        p_value = stats.binom_test(num_normal_order, num_samples, p=0.5)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        # Checking if RandomOrder can be printed as string
        random_order_transform.__repr__()

    def test_to_tensor(self):
        test_channels = [1, 3, 4]
        height, width = 4, 4
        trans = transforms.ToTensor()
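        # ToTensor converts HWC PIL images / uint8 ndarrays to CHW float
        # tensors scaled to [0, 1]; float ndarrays are only transposed,
        # as the expected outputs below assume.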

        with self.assertRaises(TypeError):
            trans(np.random.rand(1, height, width).tolist())

        # each invalid input needs its own block: inside a single assertRaises
        # block, statements after the first raise would never execute
        with self.assertRaises(ValueError):
            trans(np.random.rand(height))
        with self.assertRaises(ValueError):
            trans(np.random.rand(1, 1, height, width))

        for channels in test_channels:
            input_data = torch.ByteTensor(channels, height, width).random_(0, 255).float().div_(255)
            img = transforms.ToPILImage()(input_data)
            output = trans(img)
            self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

            ndarray = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
            output = trans(ndarray)
            expected_output = ndarray.transpose((2, 0, 1)) / 255.0
            self.assertTrue(np.allclose(output.numpy(), expected_output))

            ndarray = np.random.rand(height, width, channels).astype(np.float32)
            output = trans(ndarray)
            expected_output = ndarray.transpose((2, 0, 1))
            self.assertTrue(np.allclose(output.numpy(), expected_output))

        # separate test for mode '1' PIL images
        input_data = torch.ByteTensor(1, height, width).bernoulli_()
        img = transforms.ToPILImage()(input_data.mul(255)).convert('1')
        output = trans(img)
        self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

    @unittest.skipIf(accimage is None, 'accimage not available')
    def test_accimage_to_tensor(self):
        trans = transforms.ToTensor()

        expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
        output = trans(accimage.Image(GRACE_HOPPER))

        self.assertEqual(expected_output.size(), output.size())
        self.assertTrue(np.allclose(output.numpy(), expected_output.numpy()))

    @unittest.skipIf(accimage is None, 'accimage not available')
    def test_accimage_resize(self):
        trans = transforms.Compose([
            transforms.Resize(256, interpolation=Image.LINEAR),
            transforms.ToTensor(),
        ])

        # Checking if Compose, Resize and ToTensor can be printed as string
        trans.__repr__()

        expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
        output = trans(accimage.Image(GRACE_HOPPER))

        self.assertEqual(expected_output.size(), output.size())
        self.assertLess(np.abs((expected_output - output).mean()), 1e-3)
        self.assertLess((expected_output - output).var(), 1e-5)
        # note the high absolute tolerance
        self.assertTrue(np.allclose(output.numpy(), expected_output.numpy(), atol=5e-2))

    @unittest.skipIf(accimage is None, 'accimage not available')
    def test_accimage_crop(self):
        trans = transforms.Compose([
            transforms.CenterCrop(256),
            transforms.ToTensor(),
        ])

        # Checking if Compose, CenterCrop and ToTensor can be printed as string
        trans.__repr__()

        expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
        output = trans(accimage.Image(GRACE_HOPPER))

        self.assertEqual(expected_output.size(), output.size())
        self.assertTrue(np.allclose(output.numpy(), expected_output.numpy()))

    def test_1_channel_tensor_to_pil_image(self):
        to_tensor = transforms.ToTensor()

        img_data_float = torch.Tensor(1, 4, 4).uniform_()
        img_data_byte = torch.ByteTensor(1, 4, 4).random_(0, 255)
        img_data_short = torch.ShortTensor(1, 4, 4).random_()
        img_data_int = torch.IntTensor(1, 4, 4).random_()

        inputs = [img_data_float, img_data_byte, img_data_short, img_data_int]
        expected_outputs = [img_data_float.mul(255).int().float().div(255).numpy(),
                            img_data_byte.float().div(255.0).numpy(),
                            img_data_short.numpy(),
                            img_data_int.numpy()]
        expected_modes = ['L', 'L', 'I;16', 'I']
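        # Float data is quantized to 8 bits when rendered in mode 'L', hence
        # the mul(255).int().float().div(255) round trip above.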

        for img_data, expected_output, mode in zip(inputs, expected_outputs, expected_modes):
            for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]:
                img = transform(img_data)
                self.assertEqual(img.mode, mode)
                self.assertTrue(np.allclose(expected_output, to_tensor(img).numpy()))
        # 'F' mode for torch.FloatTensor
        img_F_mode = transforms.ToPILImage(mode='F')(img_data_float)
        self.assertEqual(img_F_mode.mode, 'F')
        self.assertTrue(np.allclose(np.array(Image.fromarray(img_data_float.squeeze(0).numpy(), mode='F')),
                                    np.array(img_F_mode)))

    def test_1_channel_ndarray_to_pil_image(self):
        img_data_float = torch.Tensor(4, 4, 1).uniform_().numpy()
        img_data_byte = torch.ByteTensor(4, 4, 1).random_(0, 255).numpy()
        img_data_short = torch.ShortTensor(4, 4, 1).random_().numpy()
        img_data_int = torch.IntTensor(4, 4, 1).random_().numpy()

        inputs = [img_data_float, img_data_byte, img_data_short, img_data_int]
        expected_modes = ['F', 'L', 'I;16', 'I']
        for img_data, mode in zip(inputs, expected_modes):
            for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]:
                img = transform(img_data)
                self.assertEqual(img.mode, mode)
                self.assertTrue(np.allclose(img_data[:, :, 0], img))

    def test_2_channel_ndarray_to_pil_image(self):
        def verify_img_data(img_data, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'LA')  # default should assume LA
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)
            split = img.split()
            for i in range(2):
                self.assertTrue(np.allclose(img_data[:, :, i], split[i]))

        img_data = torch.ByteTensor(4, 4, 2).random_(0, 255).numpy()
        for mode in [None, 'LA']:
            verify_img_data(img_data, mode)

        transforms.ToPILImage().__repr__()

        # should raise if we try a mode for 4 or 1 or 3 channel images;
        # each invalid mode needs its own assertRaises block
        for invalid_mode in ['RGBA', 'P', 'RGB']:
            with self.assertRaises(ValueError):
                transforms.ToPILImage(mode=invalid_mode)(img_data)

    def test_2_channel_tensor_to_pil_image(self):
        def verify_img_data(img_data, expected_output, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'LA')  # default should assume LA
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)
            split = img.split()
            for i in range(2):
                self.assertTrue(np.allclose(expected_output[i].numpy(), F.to_tensor(split[i]).numpy()))

        img_data = torch.Tensor(2, 4, 4).uniform_()
        expected_output = img_data.mul(255).int().float().div(255)
        for mode in [None, 'LA']:
            verify_img_data(img_data, expected_output, mode=mode)

        # should raise if we try a mode for 4 or 1 or 3 channel images;
        # each invalid mode needs its own assertRaises block
        for invalid_mode in ['RGBA', 'P', 'RGB']:
            with self.assertRaises(ValueError):
                transforms.ToPILImage(mode=invalid_mode)(img_data)

    def test_3_channel_tensor_to_pil_image(self):
        def verify_img_data(img_data, expected_output, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'RGB')  # default should assume RGB
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)
            split = img.split()
            for i in range(3):
                self.assertTrue(np.allclose(expected_output[i].numpy(), F.to_tensor(split[i]).numpy()))

        img_data = torch.Tensor(3, 4, 4).uniform_()
        expected_output = img_data.mul(255).int().float().div(255)
        for mode in [None, 'RGB', 'HSV', 'YCbCr']:
            verify_img_data(img_data, expected_output, mode=mode)

        # should raise if we try a mode for 4 or 1 or 2 channel images;
        # each invalid mode needs its own assertRaises block
        for invalid_mode in ['RGBA', 'P', 'LA']:
            with self.assertRaises(ValueError):
                transforms.ToPILImage(mode=invalid_mode)(img_data)

        with self.assertRaises(ValueError):
            transforms.ToPILImage()(torch.Tensor(1, 3, 4, 4).uniform_())

    def test_3_channel_ndarray_to_pil_image(self):
        def verify_img_data(img_data, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'RGB')  # default should assume RGB
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)
            split = img.split()
            for i in range(3):
                self.assertTrue(np.allclose(img_data[:, :, i], split[i]))

        img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()
        for mode in [None, 'RGB', 'HSV', 'YCbCr']:
            verify_img_data(img_data, mode)

        # Checking if ToPILImage can be printed as string
        transforms.ToPILImage().__repr__()

        # should raise if we try a mode for 4 or 1 or 2 channel images;
        # each invalid mode needs its own assertRaises block
        for invalid_mode in ['RGBA', 'P', 'LA']:
            with self.assertRaises(ValueError):
                transforms.ToPILImage(mode=invalid_mode)(img_data)

    def test_4_channel_tensor_to_pil_image(self):
        def verify_img_data(img_data, expected_output, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'RGBA')  # default should assume RGBA
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)

            split = img.split()
            for i in range(4):
                self.assertTrue(np.allclose(expected_output[i].numpy(), F.to_tensor(split[i]).numpy()))

        img_data = torch.Tensor(4, 4, 4).uniform_()
        expected_output = img_data.mul(255).int().float().div(255)
        for mode in [None, 'RGBA', 'CMYK', 'RGBX']:
            verify_img_data(img_data, expected_output, mode)

        # should raise if we try a mode for 3 or 1 or 2 channel images;
        # each invalid mode needs its own assertRaises block
        for invalid_mode in ['RGB', 'P', 'LA']:
            with self.assertRaises(ValueError):
                transforms.ToPILImage(mode=invalid_mode)(img_data)

    def test_4_channel_ndarray_to_pil_image(self):
        def verify_img_data(img_data, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'RGBA')  # default should assume RGBA
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)
            split = img.split()
            for i in range(4):
                self.assertTrue(np.allclose(img_data[:, :, i], split[i]))

        img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy()
        for mode in [None, 'RGBA', 'CMYK', 'RGBX']:
            verify_img_data(img_data, mode)

        # should raise if we try a mode for 3 or 1 or 2 channel images;
        # each invalid mode needs its own assertRaises block
        for invalid_mode in ['RGB', 'P', 'LA']:
            with self.assertRaises(ValueError):
                transforms.ToPILImage(mode=invalid_mode)(img_data)

    def test_2d_tensor_to_pil_image(self):
        to_tensor = transforms.ToTensor()

        img_data_float = torch.Tensor(4, 4).uniform_()
        img_data_byte = torch.ByteTensor(4, 4).random_(0, 255)
        img_data_short = torch.ShortTensor(4, 4).random_()
        img_data_int = torch.IntTensor(4, 4).random_()

        inputs = [img_data_float, img_data_byte, img_data_short, img_data_int]
        expected_outputs = [img_data_float.mul(255).int().float().div(255).numpy(),
                            img_data_byte.float().div(255.0).numpy(),
                            img_data_short.numpy(),
                            img_data_int.numpy()]
        expected_modes = ['L', 'L', 'I;16', 'I']

        for img_data, expected_output, mode in zip(inputs, expected_outputs, expected_modes):
            for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]:
                img = transform(img_data)
                self.assertEqual(img.mode, mode)
                self.assertTrue(np.allclose(expected_output, to_tensor(img).numpy()))

    def test_2d_ndarray_to_pil_image(self):
        img_data_float = torch.Tensor(4, 4).uniform_().numpy()
        img_data_byte = torch.ByteTensor(4, 4).random_(0, 255).numpy()
        img_data_short = torch.ShortTensor(4, 4).random_().numpy()
        img_data_int = torch.IntTensor(4, 4).random_().numpy()

        inputs = [img_data_float, img_data_byte, img_data_short, img_data_int]
        expected_modes = ['F', 'L', 'I;16', 'I']
        for img_data, mode in zip(inputs, expected_modes):
            for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]:
                img = transform(img_data)
                self.assertEqual(img.mode, mode)
                self.assertTrue(np.allclose(img_data, img))

    def test_tensor_bad_types_to_pil_image(self):
        with self.assertRaises(ValueError):
            transforms.ToPILImage()(torch.ones(1, 3, 4, 4))

    def test_ndarray_bad_types_to_pil_image(self):
        trans = transforms.ToPILImage()
        # each unsupported dtype needs its own assertRaises block
        for dtype in [np.int64, np.uint16, np.uint32, np.float64]:
            with self.assertRaises(TypeError):
                trans(np.ones([4, 4, 1], dtype))

        with self.assertRaises(ValueError):
            transforms.ToPILImage()(np.ones([1, 4, 4, 3]))

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_vertical_flip(self):
        random_state = random.getstate()
        random.seed(42)
        img = transforms.ToPILImage()(torch.rand(3, 10, 10))
        vimg = img.transpose(Image.FLIP_TOP_BOTTOM)

        num_samples = 250
        num_vertical = 0
        for _ in range(num_samples):
            out = transforms.RandomVerticalFlip()(img)
            if out == vimg:
                num_vertical += 1

        p_value = stats.binom_test(num_vertical, num_samples, p=0.5)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        num_samples = 250
        num_vertical = 0
        for _ in range(num_samples):
            out = transforms.RandomVerticalFlip(p=0.7)(img)
            if out == vimg:
                num_vertical += 1

        p_value = stats.binom_test(num_vertical, num_samples, p=0.7)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        # Checking if RandomVerticalFlip can be printed as string
        transforms.RandomVerticalFlip().__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_horizontal_flip(self):
        random_state = random.getstate()
        random.seed(42)
        img = transforms.ToPILImage()(torch.rand(3, 10, 10))
        himg = img.transpose(Image.FLIP_LEFT_RIGHT)

        num_samples = 250
        num_horizontal = 0
        for _ in range(num_samples):
            out = transforms.RandomHorizontalFlip()(img)
            if out == himg:
                num_horizontal += 1

        p_value = stats.binom_test(num_horizontal, num_samples, p=0.5)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        num_samples = 250
        num_horizontal = 0
        for _ in range(num_samples):
            out = transforms.RandomHorizontalFlip(p=0.7)(img)
            if out == himg:
                num_horizontal += 1

        p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        # Checking if RandomHorizontalFlip can be printed as string
        transforms.RandomHorizontalFlip().__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats is not available')
    def test_normalize(self):
        def samples_from_standard_normal(tensor):
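            # Kolmogorov-Smirnov test: a large p-value means we cannot reject
            # that the normalized values are drawn from N(0, 1).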
            p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue
            return p_value > 0.0001

        random_state = random.getstate()
        random.seed(42)
        for channels in [1, 3]:
            img = torch.rand(channels, 10, 10)
            mean = [img[c].mean() for c in range(channels)]
            std = [img[c].std() for c in range(channels)]
            normalized = transforms.Normalize(mean, std)(img)
            self.assertTrue(samples_from_standard_normal(normalized))
        random.setstate(random_state)

        # Checking if Normalize can be printed as string
        transforms.Normalize(mean, std).__repr__()

        # Checking the optional in-place behaviour
        tensor = torch.rand((1, 16, 16))
        tensor_inplace = transforms.Normalize((0.5,), (0.5,), inplace=True)(tensor)
        self.assertTrue(torch.equal(tensor, tensor_inplace))

    def test_normalize_different_dtype(self):
        for dtype1 in [torch.float32, torch.float64]:
            img = torch.rand(3, 10, 10, dtype=dtype1)
            for dtype2 in [torch.int64, torch.float32, torch.float64]:
                mean = torch.tensor([1, 2, 3], dtype=dtype2)
                std = torch.tensor([1, 2, 1], dtype=dtype2)
                # checks that it doesn't crash
                transforms.functional.normalize(img, mean, std)

    def test_adjust_brightness(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')

        # test 0
        y_pil = F.adjust_brightness(x_pil, 1)
        y_np = np.array(y_pil)
        self.assertTrue(np.allclose(y_np, x_np))

        # test 1
        y_pil = F.adjust_brightness(x_pil, 0.5)
        y_np = np.array(y_pil)
        y_ans = [0, 2, 6, 27, 67, 113, 18, 4, 117, 45, 127, 0]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 2
        y_pil = F.adjust_brightness(x_pil, 2)
        y_np = np.array(y_pil)
        y_ans = [0, 10, 26, 108, 255, 255, 74, 16, 255, 180, 255, 2]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

    def test_adjust_contrast(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')

        # test 0
        y_pil = F.adjust_contrast(x_pil, 1)
        y_np = np.array(y_pil)
        self.assertTrue(np.allclose(y_np, x_np))

        # test 1
        y_pil = F.adjust_contrast(x_pil, 0.5)
        y_np = np.array(y_pil)
        y_ans = [43, 45, 49, 70, 110, 156, 61, 47, 160, 88, 170, 43]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 2
        y_pil = F.adjust_contrast(x_pil, 2)
        y_np = np.array(y_pil)
        y_ans = [0, 0, 0, 22, 184, 255, 0, 0, 255, 94, 255, 0]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

    @unittest.skipIf(Image.__version__ >= '7', "Temporarily disabled")
    def test_adjust_saturation(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')

        # test 0
        y_pil = F.adjust_saturation(x_pil, 1)
        y_np = np.array(y_pil)
        self.assertTrue(np.allclose(y_np, x_np))

        # test 1
        y_pil = F.adjust_saturation(x_pil, 0.5)
        y_np = np.array(y_pil)
        y_ans = [2, 4, 8, 87, 128, 173, 39, 25, 138, 133, 215, 88]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 2
        y_pil = F.adjust_saturation(x_pil, 2)
        y_np = np.array(y_pil)
        y_ans = [0, 6, 22, 0, 149, 255, 32, 0, 255, 4, 255, 0]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

    def test_adjust_hue(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')

        # the hue factor must lie in [-0.5, 0.5]; each invalid value needs
        # its own assertRaises block
        for hue_factor in [-0.7, 1]:
            with self.assertRaises(ValueError):
                F.adjust_hue(x_pil, hue_factor)

        # test 0: output is almost the same as x_data, but not exact,
        # probably because of hsv <-> rgb floating point rounding
        y_pil = F.adjust_hue(x_pil, 0)
        y_np = np.array(y_pil)
        y_ans = [0, 5, 13, 54, 139, 226, 35, 8, 234, 91, 255, 1]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 1
        y_pil = F.adjust_hue(x_pil, 0.25)
        y_np = np.array(y_pil)
        y_ans = [13, 0, 12, 224, 54, 226, 234, 8, 99, 1, 222, 255]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 2
        y_pil = F.adjust_hue(x_pil, -0.25)
        y_np = np.array(y_pil)
        y_ans = [0, 13, 2, 54, 226, 58, 8, 234, 152, 255, 43, 1]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

    def test_adjust_gamma(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')

        # test 0
        y_pil = F.adjust_gamma(x_pil, 1)
        y_np = np.array(y_pil)
        self.assertTrue(np.allclose(y_np, x_np))

        # test 1
        y_pil = F.adjust_gamma(x_pil, 0.5)
        y_np = np.array(y_pil)
        y_ans = [0, 35, 57, 117, 185, 240, 97, 45, 244, 151, 255, 15]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 2
        y_pil = F.adjust_gamma(x_pil, 2)
        y_np = np.array(y_pil)
        y_ans = [0, 0, 0, 11, 71, 200, 5, 0, 214, 31, 255, 0]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

    def test_adjusts_L_mode(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_rgb = Image.fromarray(x_np, mode='RGB')

        x_l = x_rgb.convert('L')
        self.assertEqual(F.adjust_brightness(x_l, 2).mode, 'L')
        self.assertEqual(F.adjust_saturation(x_l, 2).mode, 'L')
        self.assertEqual(F.adjust_contrast(x_l, 2).mode, 'L')
        self.assertEqual(F.adjust_hue(x_l, 0.4).mode, 'L')
        self.assertEqual(F.adjust_gamma(x_l, 0.5).mode, 'L')

    def test_color_jitter(self):
        color_jitter = transforms.ColorJitter(2, 2, 2, 0.1)

        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')
        x_pil_2 = x_pil.convert('L')

        for i in range(10):
            y_pil = color_jitter(x_pil)
            self.assertEqual(y_pil.mode, x_pil.mode)

            y_pil_2 = color_jitter(x_pil_2)
            self.assertEqual(y_pil_2.mode, x_pil_2.mode)

        # Checking if ColorJitter can be printed as string
        color_jitter.__repr__()

    def test_linear_transformation(self):
        num_samples = 1000
        x = torch.randn(num_samples, 3, 10, 10)
        flat_x = x.view(x.size(0), x.size(1) * x.size(2) * x.size(3))
        # compute principal components
        sigma = torch.mm(flat_x.t(), flat_x) / flat_x.size(0)
        u, s, _ = np.linalg.svd(sigma.numpy())
        zca_epsilon = 1e-10  # avoid division by 0
        d = torch.Tensor(np.diag(1. / np.sqrt(s + zca_epsilon)))
        u = torch.Tensor(u)
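        # principal_components is the ZCA whitening matrix
        # U @ diag(1 / sqrt(s)) @ U.T computed from the covariance estimate.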
        principal_components = torch.mm(torch.mm(u, d), u.t())
        mean_vector = (torch.sum(flat_x, dim=0) / flat_x.size(0))
        # initialize whitening matrix
        whitening = transforms.LinearTransformation(principal_components, mean_vector)
        # estimate covariance and mean using weak law of large number
        num_features = flat_x.size(1)
        cov = 0.0
        mean = 0.0
        for i in x:
            xwhite = whitening(i)
            xwhite = xwhite.view(1, -1).numpy()
            cov += np.dot(xwhite, xwhite.T) / num_features
            mean += np.sum(xwhite) / num_features
        # if rtol for std = 1e-3 then rtol for cov = 2e-3 as std**2 = cov
        self.assertTrue(np.allclose(cov / num_samples, np.identity(1), rtol=2e-3),
                        "cov not close to 1")
        self.assertTrue(np.allclose(mean / num_samples, 0, rtol=1e-3),
                        "mean not close to 0")

        # Checking if LinearTransformation can be printed as string
        whitening.__repr__()

    def test_rotate(self):
        x = np.zeros((100, 100, 3), dtype=np.uint8)
        x[40, 40] = [255, 255, 255]

        with self.assertRaises(TypeError):
            F.rotate(x, 10)

        img = F.to_pil_image(x)

        result = F.rotate(img, 45)
        self.assertEqual(result.size, (100, 100))
        r, c, ch = np.where(result)
        self.assertTrue(all(x in r for x in [49, 50]))
        self.assertTrue(all(x in c for x in [36]))
        self.assertTrue(all(x in ch for x in [0, 1, 2]))

        result = F.rotate(img, 45, expand=True)
        self.assertEqual(result.size, (142, 142))
        r, c, ch = np.where(result)
        self.assertTrue(all(x in r for x in [70, 71]))
        self.assertTrue(all(x in c for x in [57]))
        self.assertTrue(all(x in ch for x in [0, 1, 2]))

        result = F.rotate(img, 45, center=(40, 40))
        self.assertEqual(result.size, (100, 100))
        r, c, ch = np.where(result)
        self.assertTrue(all(x in r for x in [40]))
        self.assertTrue(all(x in c for x in [40]))
        self.assertTrue(all(x in ch for x in [0, 1, 2]))

        result_a = F.rotate(img, 90)
        result_b = F.rotate(img, -270)

        self.assertTrue(np.all(np.array(result_a) == np.array(result_b)))

    def test_rotate_fill(self):
        img = F.to_pil_image(np.ones((100, 100, 3), dtype=np.uint8) * 255, "RGB")

        modes = ("L", "RGB")
        nums_bands = [len(mode) for mode in modes]
        fill = 127

        for mode, num_bands in zip(modes, nums_bands):
            img_conv = img.convert(mode)
            img_rot = F.rotate(img_conv, 45.0, fill=fill)
            pixel = img_rot.getpixel((0, 0))

            if not isinstance(pixel, tuple):
                pixel = (pixel,)
            self.assertTupleEqual(pixel, tuple([fill] * num_bands))

            for wrong_num_bands in set(nums_bands) - {num_bands}:
                with self.assertRaises(ValueError):
                    F.rotate(img_conv, 45.0, fill=tuple([fill] * wrong_num_bands))

    def test_affine(self):
        input_img = np.zeros((40, 40, 3), dtype=np.uint8)
        pts = []
        cnt = [20, 20]
        for pt in [(16, 16), (20, 16), (20, 20)]:
            for i in range(-5, 5):
                for j in range(-5, 5):
                    input_img[pt[0] + i, pt[1] + j, :] = [255, 155, 55]
                    pts.append((pt[0] + i, pt[1] + j))
        pts = list(set(pts))

        with self.assertRaises(TypeError):
            F.affine(input_img, 10)

        pil_img = F.to_pil_image(input_img)

        def _to_3x3_inv(inv_result_matrix):
            result_matrix = np.zeros((3, 3))
            result_matrix[:2, :] = np.array(inv_result_matrix).reshape((2, 3))
            result_matrix[2, 2] = 1
            return np.linalg.inv(result_matrix)

        def _test_transformation(a, t, s, sh):
            a_rad = math.radians(a)
            sh_rad = [math.radians(sh_) for sh_ in sh]
            cx, cy = cnt
            tx, ty = t
            sx, sy = sh_rad
            rot = a_rad

            # 1) Check transformation matrix:
            C = np.array([[1, 0, cx],
                          [0, 1, cy],
                          [0, 0, 1]])
            T = np.array([[1, 0, tx],
                          [0, 1, ty],
                          [0, 0, 1]])
            Cinv = np.linalg.inv(C)

            RS = np.array(
                [[s * math.cos(rot), -s * math.sin(rot), 0],
                 [s * math.sin(rot), s * math.cos(rot), 0],
                 [0, 0, 1]])

            SHx = np.array([[1, -math.tan(sx), 0],
                            [0, 1, 0],
                            [0, 0, 1]])

            SHy = np.array([[1, 0, 0],
                            [-math.tan(sy), 1, 0],
                            [0, 0, 1]])

            RSS = np.matmul(RS, np.matmul(SHy, SHx))

            true_matrix = np.matmul(T, np.matmul(C, np.matmul(RSS, Cinv)))
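            # The reference affine matrix translates after rotating, scaling
            # and shearing about the center: T @ C @ RSS @ C^-1.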

            result_matrix = _to_3x3_inv(F._get_inverse_affine_matrix(center=cnt, angle=a,
                                                                     translate=t, scale=s, shear=sh))
            self.assertLess(np.sum(np.abs(true_matrix - result_matrix)), 1e-10)
            # 2) Perform inverse mapping:
            true_result = np.zeros((40, 40, 3), dtype=np.uint8)
            inv_true_matrix = np.linalg.inv(true_matrix)
            for y in range(true_result.shape[0]):
                for x in range(true_result.shape[1]):
                    res = np.dot(inv_true_matrix, [x, y, 1])
                    _x = int(res[0] + 0.5)
                    _y = int(res[1] + 0.5)
                    if 0 <= _x < input_img.shape[1] and 0 <= _y < input_img.shape[0]:
                        true_result[y, x, :] = input_img[_y, _x, :]

            result = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh)
            self.assertEqual(result.size, pil_img.size)
            # Compute number of different pixels:
            np_result = np.array(result)
            n_diff_pixels = np.sum(np_result != true_result) / 3
            # Accept fewer than 3 differing pixels
            self.assertLess(n_diff_pixels, 3,
                            "a={}, t={}, s={}, sh={}\n".format(a, t, s, sh) +
                            "n diff pixels={}\n".format(np.sum(np.array(result)[:, :, 0] != true_result[:, :, 0])))

        # Test rotation
        a = 45
        _test_transformation(a=a, t=(0, 0), s=1.0, sh=(0.0, 0.0))

        # Test translation
        t = [10, 15]
        _test_transformation(a=0.0, t=t, s=1.0, sh=(0.0, 0.0))

        # Test scale
        s = 1.2
        _test_transformation(a=0.0, t=(0.0, 0.0), s=s, sh=(0.0, 0.0))

        # Test shear
        sh = [45.0, 25.0]
        _test_transformation(a=0.0, t=(0.0, 0.0), s=1.0, sh=sh)

        # Test rotation, scale, translation, shear
        for a in range(-90, 90, 25):
            for t1 in range(-10, 10, 5):
                for s in [0.75, 0.98, 1.0, 1.1, 1.2]:
                    for sh in range(-15, 15, 5):
                        _test_transformation(a=a, t=(t1, t1), s=s, sh=(sh, sh))

    def test_random_rotation(self):

        # each invalid degrees value needs its own assertRaises block
        for degrees in [-0.7, [-0.7], [-0.7, 0, 0.7]]:
            with self.assertRaises(ValueError):
                transforms.RandomRotation(degrees)

        t = transforms.RandomRotation(10)
        angle = t.get_params(t.degrees)
        self.assertTrue(-10 < angle < 10)

        t = transforms.RandomRotation((-10, 10))
        angle = t.get_params(t.degrees)
        self.assertTrue(-10 < angle < 10)

        # Checking if RandomRotation can be printed as string
        t.__repr__()
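
        # End-to-end sanity check (a minimal sketch): with the default
        # expand=False the rotated output keeps the input size.
        img = transforms.ToPILImage()(torch.rand(3, 10, 10))
        self.assertEqual(t(img).size, img.size)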

    def test_random_affine(self):

        # Each invalid argument combination must raise on its own; grouping
        # them under one assertRaises would stop at the first raise. Some
        # argument-shape checks in this version are plain `assert`s, so
        # AssertionError is accepted alongside ValueError.
        invalid_inputs = [
            dict(degrees=-0.7),
            dict(degrees=[-0.7]),
            dict(degrees=[-0.7, 0, 0.7]),
            dict(degrees=[-90, 90], translate=2.0),
            dict(degrees=[-90, 90], translate=[-1.0, 1.0]),
            dict(degrees=[-90, 90], translate=[-1.0, 0.0, 1.0]),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[0.0]),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[-1.0, 1.0]),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[0.5, -0.5]),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[0.5, 3.0, -0.5]),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=-7),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10]),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 0, 10]),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 0, 10, 0, 10]),
        ]
        for kwargs in invalid_inputs:
            with self.assertRaises((ValueError, AssertionError)):
                transforms.RandomAffine(**kwargs)

        x = np.zeros((100, 100, 3), dtype=np.uint8)
        img = F.to_pil_image(x)

        t = transforms.RandomAffine(10, translate=[0.5, 0.3], scale=[0.7, 1.3], shear=[-10, 10, 20, 40])
        for _ in range(100):
            angle, translations, scale, shear = t.get_params(t.degrees, t.translate, t.scale, t.shear,
                                                             img_size=img.size)
            self.assertTrue(-10 < angle < 10)
            self.assertTrue(-img.size[0] * 0.5 <= translations[0] <= img.size[0] * 0.5,
                            "{} vs {}".format(translations[0], img.size[0] * 0.5))
            self.assertTrue(-img.size[1] * 0.5 <= translations[1] <= img.size[1] * 0.5,
                            "{} vs {}".format(translations[1], img.size[1] * 0.5))
            self.assertTrue(0.7 < scale < 1.3)
            self.assertTrue(-10 < shear[0] < 10)
            self.assertTrue(-20 < shear[1] < 40)
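
        # Applying the transform end-to-end should likewise preserve the
        # canvas size (a sketch; F.affine keeps the input dimensions):
        self.assertEqual(t(img).size, img.size)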

        # Checking if RandomAffine can be printed as string
        t.__repr__()

        t = transforms.RandomAffine(10, resample=Image.BILINEAR)
        self.assertIn("Image.BILINEAR", t.__repr__())

    def test_to_grayscale(self):
        """Unit tests for grayscale transform"""

        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')
        x_pil_2 = x_pil.convert('L')
        gray_np = np.array(x_pil_2)

        # Test Set: Grayscale an image with desired number of output channels
        # Case 1: RGB -> 1 channel grayscale
        trans1 = transforms.Grayscale(num_output_channels=1)
        gray_pil_1 = trans1(x_pil)
        gray_np_1 = np.array(gray_pil_1)
        self.assertEqual(gray_pil_1.mode, 'L', 'mode should be L')
        self.assertEqual(gray_np_1.shape, tuple(x_shape[0:2]), 'should be 1 channel')
        np.testing.assert_equal(gray_np, gray_np_1)
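
        # PIL's 'L' conversion uses the ITU-R 601-2 luma transform
        # L = R * 299/1000 + G * 587/1000 + B * 114/1000; recompute it by hand
        # as a cross-check (PIL's fixed-point rounding may differ by one):
        rgb = x_np.astype(np.int64)
        luma = (rgb[..., 0] * 299 + rgb[..., 1] * 587 + rgb[..., 2] * 114) / 1000.0
        np.testing.assert_allclose(luma, gray_np, atol=1)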

        # Case 2: RGB -> 3 channel grayscale
        trans2 = transforms.Grayscale(num_output_channels=3)
        gray_pil_2 = trans2(x_pil)
        gray_np_2 = np.array(gray_pil_2)
        self.assertEqual(gray_pil_2.mode, 'RGB', 'mode should be RGB')
        self.assertEqual(gray_np_2.shape, tuple(x_shape), 'should be 3 channel')
        np.testing.assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
        np.testing.assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
        np.testing.assert_equal(gray_np, gray_np_2[:, :, 0])

        # Case 3: 1 channel grayscale -> 1 channel grayscale
        trans3 = transforms.Grayscale(num_output_channels=1)
        gray_pil_3 = trans3(x_pil_2)
        gray_np_3 = np.array(gray_pil_3)
        self.assertEqual(gray_pil_3.mode, 'L', 'mode should be L')
        self.assertEqual(gray_np_3.shape, tuple(x_shape[0:2]), 'should be 1 channel')
        np.testing.assert_equal(gray_np, gray_np_3)

        # Case 4: 1 channel grayscale -> 3 channel grayscale
        trans4 = transforms.Grayscale(num_output_channels=3)
        gray_pil_4 = trans4(x_pil_2)
        gray_np_4 = np.array(gray_pil_4)
        self.assertEqual(gray_pil_4.mode, 'RGB', 'mode should be RGB')
        self.assertEqual(gray_np_4.shape, tuple(x_shape), 'should be 3 channel')
        np.testing.assert_equal(gray_np_4[:, :, 0], gray_np_4[:, :, 1])
        np.testing.assert_equal(gray_np_4[:, :, 1], gray_np_4[:, :, 2])
        np.testing.assert_equal(gray_np, gray_np_4[:, :, 0])

        # Checking if Grayscale can be printed as string
        trans4.__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_grayscale(self):
        """Unit tests for random grayscale transform"""

        # Test Set 1: RGB -> 3 channel grayscale
        random_state = random.getstate()
        random.seed(42)
        x_shape = [2, 2, 3]
        x_np = np.random.randint(0, 256, x_shape, np.uint8)
        x_pil = Image.fromarray(x_np, mode='RGB')
        x_pil_2 = x_pil.convert('L')
        gray_np = np.array(x_pil_2)

        num_samples = 250
        num_gray = 0
        for _ in range(num_samples):
            gray_pil_2 = transforms.RandomGrayscale(p=0.5)(x_pil)
            gray_np_2 = np.array(gray_pil_2)
            if np.array_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) and \
                    np.array_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) and \
                    np.array_equal(gray_np, gray_np_2[:, :, 0]):
                num_gray = num_gray + 1

        p_value = stats.binom_test(num_gray, num_samples, p=0.5)
        random.setstate(random_state)
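        # Under the null hypothesis num_gray ~ Binomial(num_samples, 0.5);
        # a two-sided p-value below 1e-4 would flag a biased coin flip.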
        self.assertGreater(p_value, 0.0001)

        # Test Set 2: grayscale -> 1 channel grayscale
        random_state = random.getstate()
        random.seed(42)
        x_shape = [2, 2, 3]
        x_np = np.random.randint(0, 256, x_shape, np.uint8)
        x_pil = Image.fromarray(x_np, mode='RGB')
        x_pil_2 = x_pil.convert('L')
        gray_np = np.array(x_pil_2)

        num_samples = 250
        num_gray = 0
        for _ in range(num_samples):
            gray_pil_3 = transforms.RandomGrayscale(p=0.5)(x_pil_2)
            gray_np_3 = np.array(gray_pil_3)
            if np.array_equal(gray_np, gray_np_3):
                num_gray = num_gray + 1

        p_value = stats.binom_test(num_gray, num_samples, p=1.0)  # Note: grayscale is always unchanged
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        # Test set 3: Explicit tests
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')
        x_pil_2 = x_pil.convert('L')
        gray_np = np.array(x_pil_2)

        # Case 3a: RGB -> 3 channel grayscale (grayscaled)
        trans2 = transforms.RandomGrayscale(p=1.0)
        gray_pil_2 = trans2(x_pil)
        gray_np_2 = np.array(gray_pil_2)
        self.assertEqual(gray_pil_2.mode, 'RGB', 'mode should be RGB')
        self.assertEqual(gray_np_2.shape, tuple(x_shape), 'should be 3 channel')
        np.testing.assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
        np.testing.assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
        np.testing.assert_equal(gray_np, gray_np_2[:, :, 0])

        # Case 3b: RGB -> 3 channel grayscale (unchanged)
        trans2 = transforms.RandomGrayscale(p=0.0)
        gray_pil_2 = trans2(x_pil)
        gray_np_2 = np.array(gray_pil_2)
        self.assertEqual(gray_pil_2.mode, 'RGB', 'mode should be RGB')
        self.assertEqual(gray_np_2.shape, tuple(x_shape), 'should be 3 channel')
        np.testing.assert_equal(x_np, gray_np_2)

        # Case 3c: 1 channel grayscale -> 1 channel grayscale (grayscaled)
        trans3 = transforms.RandomGrayscale(p=1.0)
        gray_pil_3 = trans3(x_pil_2)
        gray_np_3 = np.array(gray_pil_3)
        self.assertEqual(gray_pil_3.mode, 'L', 'mode should be L')
        self.assertEqual(gray_np_3.shape, tuple(x_shape[0:2]), 'should be 1 channel')
        np.testing.assert_equal(gray_np, gray_np_3)

        # Case 3d: 1 channel grayscale -> 1 channel grayscale (unchanged)
        trans3 = transforms.RandomGrayscale(p=0.0)
        gray_pil_3 = trans3(x_pil_2)
        gray_np_3 = np.array(gray_pil_3)
        self.assertEqual(gray_pil_3.mode, 'L', 'mode should be L')
        self.assertEqual(gray_np_3.shape, tuple(x_shape[0:2]), 'should be 1 channel')
        np.testing.assert_equal(gray_np, gray_np_3)

        # Checking if RandomGrayscale can be printed as string
        trans3.__repr__()

    def test_random_erasing(self):
        """Unit tests for random erasing transform"""

        img = torch.rand([3, 60, 60])

        # Test Set 1: Erasing with a scalar value
        img_re = transforms.RandomErasing(value=0.2)
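        # get_params returns the region to erase: (i, j) is the top-left
        # corner, (h, w) its size, and v the fill value.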
        i, j, h, w, v = img_re.get_params(img, scale=img_re.scale, ratio=img_re.ratio, value=img_re.value)
        img_output = F.erase(img, i, j, h, w, v)
        self.assertEqual(img_output.size(0), 3)

        # Test Set 2: Check if the unerased region is preserved
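        # (zeroing the erased rectangle in both tensors leaves only the
        # untouched pixels to take part in the equality check)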
        orig_unerased = img.clone()
        orig_unerased[:, i:i + h, j:j + w] = 0
        output_unerased = img_output.clone()
        output_unerased[:, i:i + h, j:j + w] = 0
        self.assertTrue(torch.equal(orig_unerased, output_unerased))

        # Test Set 3: Erasing with random value
        img_re = transforms.RandomErasing(value='random')(img)
        self.assertEqual(img_re.size(0), 3)

        # Test Set 4: Erasing with tuple value
        img_re = transforms.RandomErasing(value=(0.2, 0.2, 0.2))(img)
        self.assertEqual(img_re.size(0), 3)

        # Test Set 5: Testing the inplace behaviour
        img_re = transforms.RandomErasing(value=0.2, inplace=True)(img)
        self.assertTrue(torch.equal(img_re, img))

        # Test Set 6: Checking when no erased region is selected
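        # (no candidate region with this aspect ratio fits a 1-pixel-wide
        # image, so the transform should fall back to returning the input)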
        img = torch.rand([3, 300, 1])
        img_re = transforms.RandomErasing(ratio=(0.1, 0.2), value='random')(img)
        self.assertTrue(torch.equal(img_re, img))


if __name__ == '__main__':
    unittest.main()