import os
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from torch._utils_internal import get_file_path_2
from numpy.testing import assert_array_almost_equal
import unittest
import math
import random
import numpy as np
from PIL import Image
try:
    import accimage
except ImportError:
    accimage = None

try:
    from scipy import stats
except ImportError:
    stats = None

GRACE_HOPPER = get_file_path_2(
    os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')


class Tester(unittest.TestCase):

    def test_crop(self):
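        # Strategy: start from an all-ones image and zero out exactly the region
        # CenterCrop should return; a correct center crop then sums to zero.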
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        oheight = random.randint(5, (height - 2) // 2) * 2
        owidth = random.randint(5, (width - 2) // 2) * 2

        img = torch.ones(3, height, width)
        oh1 = (height - oheight) // 2
        ow1 = (width - owidth) // 2
        imgnarrow = img[:, oh1:oh1 + oheight, ow1:ow1 + owidth]
        imgnarrow.fill_(0)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.CenterCrop((oheight, owidth)),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.sum(), 0,
                         "height: {} width: {} oheight: {} owidth: {}".format(height, width, oheight, owidth))
        oheight += 1
        owidth += 1
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.CenterCrop((oheight, owidth)),
            transforms.ToTensor(),
        ])(img)
        sum1 = result.sum()
        self.assertGreater(sum1, 1,
                           "height: {} width: {} oheight: {} owidth: {}".format(height, width, oheight, owidth))
        oheight += 1
        owidth += 1
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.CenterCrop((oheight, owidth)),
            transforms.ToTensor(),
        ])(img)
        sum2 = result.sum()
        self.assertGreater(sum2, 0,
                           "height: {} width: {} oheight: {} owidth: {}".format(height, width, oheight, owidth))
        self.assertGreater(sum2, sum1,
                           "height: {} width: {} oheight: {} owidth: {}".format(height, width, oheight, owidth))

    def test_five_crop(self):
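        # FiveCrop returns (tl, tr, bl, br, center); each crop is checked against
        # a slice taken directly from the source tensor.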
        to_pil_image = transforms.ToPILImage()
        h = random.randint(5, 25)
        w = random.randint(5, 25)
        for single_dim in [True, False]:
            crop_h = random.randint(1, h)
            crop_w = random.randint(1, w)
            if single_dim:
                crop_h = min(crop_h, crop_w)
                crop_w = crop_h
                transform = transforms.FiveCrop(crop_h)
            else:
                transform = transforms.FiveCrop((crop_h, crop_w))

            img = torch.FloatTensor(3, h, w).uniform_()
            results = transform(to_pil_image(img))

            self.assertEqual(len(results), 5)
            for crop in results:
                self.assertEqual(crop.size, (crop_w, crop_h))

            to_pil_image = transforms.ToPILImage()
            tl = to_pil_image(img[:, 0:crop_h, 0:crop_w])
            tr = to_pil_image(img[:, 0:crop_h, w - crop_w:])
            bl = to_pil_image(img[:, h - crop_h:, 0:crop_w])
            br = to_pil_image(img[:, h - crop_h:, w - crop_w:])
            center = transforms.CenterCrop((crop_h, crop_w))(to_pil_image(img))
            expected_output = (tl, tr, bl, br, center)
            self.assertEqual(results, expected_output)

    def test_ten_crop(self):
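        # TenCrop should equal FiveCrop of the image plus FiveCrop of its flipped
        # copy (vertical flip when vertical_flip=True, horizontal otherwise).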
        to_pil_image = transforms.ToPILImage()
        h = random.randint(5, 25)
        w = random.randint(5, 25)
        for should_vflip in [True, False]:
            for single_dim in [True, False]:
                crop_h = random.randint(1, h)
                crop_w = random.randint(1, w)
                if single_dim:
                    crop_h = min(crop_h, crop_w)
                    crop_w = crop_h
                    transform = transforms.TenCrop(crop_h,
                                                   vertical_flip=should_vflip)
                    five_crop = transforms.FiveCrop(crop_h)
                else:
                    transform = transforms.TenCrop((crop_h, crop_w),
                                                   vertical_flip=should_vflip)
                    five_crop = transforms.FiveCrop((crop_h, crop_w))

                img = to_pil_image(torch.FloatTensor(3, h, w).uniform_())
                results = transform(img)
                expected_output = five_crop(img)

                # Checking if FiveCrop and TenCrop can be printed as string
                transform.__repr__()
                five_crop.__repr__()

                if should_vflip:
                    vflipped_img = img.transpose(Image.FLIP_TOP_BOTTOM)
                    expected_output += five_crop(vflipped_img)
                else:
                    hflipped_img = img.transpose(Image.FLIP_LEFT_RIGHT)
                    expected_output += five_crop(hflipped_img)

                self.assertEqual(len(results), 10)
                self.assertEqual(results, expected_output)

    def test_randomresized_params(self):
        height = random.randint(24, 32) * 2
        width = random.randint(24, 32) * 2
        img = torch.ones(3, height, width)
        to_pil_image = transforms.ToPILImage()
        img = to_pil_image(img)
        size = 100
        epsilon = 0.05
        min_scale = 0.25
        for _ in range(10):
            scale_min = max(round(random.random(), 2), min_scale)
            scale_range = (scale_min, scale_min + round(random.random(), 2))
            aspect_min = max(round(random.random(), 2), epsilon)
            aspect_ratio_range = (aspect_min, aspect_min + round(random.random(), 2))
            randresizecrop = transforms.RandomResizedCrop(size, scale_range, aspect_ratio_range)
            i, j, h, w = randresizecrop.get_params(img, scale_range, aspect_ratio_range)
            aspect_ratio_obtained = w / h
            self.assertTrue(min(aspect_ratio_range) - epsilon <= aspect_ratio_obtained <= max(aspect_ratio_range) + epsilon or
                            aspect_ratio_obtained == 1.0)
            self.assertIsInstance(i, int)
            self.assertIsInstance(j, int)
            self.assertIsInstance(h, int)
            self.assertIsInstance(w, int)

    def test_randomperspective(self):
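        # Warping with (startpoints -> endpoints) and then back again should
        # roughly restore the image, so the round-trip MSE must stay below the
        # one-way MSE plus a 0.3 margin.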
        for _ in range(10):
            height = random.randint(24, 32) * 2
            width = random.randint(24, 32) * 2
            img = torch.ones(3, height, width)
            to_pil_image = transforms.ToPILImage()
            img = to_pil_image(img)
            perp = transforms.RandomPerspective()
            startpoints, endpoints = perp.get_params(width, height, 0.5)
            tr_img = F.perspective(img, startpoints, endpoints)
            tr_img2 = F.to_tensor(F.perspective(tr_img, endpoints, startpoints))
            tr_img = F.to_tensor(tr_img)
            self.assertEqual(img.size[0], width)
            self.assertEqual(img.size[1], height)
            self.assertGreater(torch.nn.functional.mse_loss(tr_img, F.to_tensor(img)) + 0.3,
                               torch.nn.functional.mse_loss(tr_img2, F.to_tensor(img)))

    def test_randomperspective_fill(self):
        height = 100
        width = 100
        img = torch.ones(3, height, width)
        to_pil_image = transforms.ToPILImage()
        img = to_pil_image(img)

        modes = ("L", "RGB", "F")
        nums_bands = [len(mode) for mode in modes]
        fill = 127

        for mode, num_bands in zip(modes, nums_bands):
            img_conv = img.convert(mode)
            perspective = transforms.RandomPerspective(p=1, fill=fill)
            tr_img = perspective(img_conv)
            pixel = tr_img.getpixel((0, 0))

            if not isinstance(pixel, tuple):
                pixel = (pixel,)
            self.assertTupleEqual(pixel, tuple([fill] * num_bands))

        for mode, num_bands in zip(modes, nums_bands):
            img_conv = img.convert(mode)
            startpoints, endpoints = transforms.RandomPerspective.get_params(width, height, 0.5)
            tr_img = F.perspective(img_conv, startpoints, endpoints, fill=fill)
            pixel = tr_img.getpixel((0, 0))

            if not isinstance(pixel, tuple):
                pixel = (pixel,)
            self.assertTupleEqual(pixel, tuple([fill] * num_bands))

            for wrong_num_bands in set(nums_bands) - {num_bands}:
                with self.assertRaises(ValueError):
                    F.perspective(img_conv, startpoints, endpoints, fill=tuple([fill] * wrong_num_bands))

    def test_resize(self):
        height = random.randint(24, 32) * 2
        width = random.randint(24, 32) * 2
        osize = random.randint(5, 12) * 2

        img = torch.ones(3, height, width)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(osize),
            transforms.ToTensor(),
        ])(img)
        self.assertIn(osize, result.size())
        if height < width:
            self.assertLessEqual(result.size(1), result.size(2))
        elif width < height:
            self.assertGreaterEqual(result.size(1), result.size(2))

        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize([osize, osize]),
            transforms.ToTensor(),
        ])(img)
        self.assertIn(osize, result.size())
        self.assertEqual(result.size(1), osize)
        self.assertEqual(result.size(2), osize)

        oheight = random.randint(5, 12) * 2
        owidth = random.randint(5, 12) * 2
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((oheight, owidth)),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), oheight)
        self.assertEqual(result.size(2), owidth)

        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize([oheight, owidth]),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), oheight)
        self.assertEqual(result.size(2), owidth)

    def test_random_crop(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        oheight = random.randint(5, (height - 2) // 2) * 2
        owidth = random.randint(5, (width - 2) // 2) * 2
        img = torch.ones(3, height, width)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop((oheight, owidth)),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), oheight)
        self.assertEqual(result.size(2), owidth)

        padding = random.randint(1, 20)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop((oheight, owidth), padding=padding),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), oheight)
        self.assertEqual(result.size(2), owidth)

        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop((height, width)),
            transforms.ToTensor()
        ])(img)
        self.assertEqual(result.size(1), height)
        self.assertEqual(result.size(2), width)
        self.assertTrue(np.allclose(img.numpy(), result.numpy()))

        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop((height + 1, width + 1), pad_if_needed=True),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), height + 1)
        self.assertEqual(result.size(2), width + 1)

    def test_pad(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        img = torch.ones(3, height, width)
        padding = random.randint(1, 20)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Pad(padding),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), height + 2 * padding)
        self.assertEqual(result.size(2), width + 2 * padding)

    def test_pad_with_tuple_of_pad_values(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        img = transforms.ToPILImage()(torch.ones(3, height, width))

        padding = tuple([random.randint(1, 20) for _ in range(2)])
        output = transforms.Pad(padding)(img)
        self.assertEqual(output.size, (width + padding[0] * 2, height + padding[1] * 2))

        padding = tuple([random.randint(1, 20) for _ in range(4)])
        output = transforms.Pad(padding)(img)
        self.assertEqual(output.size[0], width + padding[0] + padding[2])
        self.assertEqual(output.size[1], height + padding[1] + padding[3])

        # Checking if Padding can be printed as string
        transforms.Pad(padding).__repr__()

    def test_pad_with_non_constant_padding_modes(self):
        """Unit tests for edge, reflect, symmetric padding"""
        img = torch.zeros(3, 27, 27).byte()
        img[:, :, 0] = 1  # Constant value added to leftmost edge
        img = transforms.ToPILImage()(img)
        img = F.pad(img, 1, (200, 200, 200))

        # Pad 3 to all sides
        edge_padded_img = F.pad(img, 3, padding_mode='edge')
        # First 6 elements of leftmost edge in the middle of the image, values are in order:
        # edge_pad, edge_pad, edge_pad, constant_pad, constant value added to leftmost edge, 0
        edge_middle_slice = np.asarray(edge_padded_img).transpose(2, 0, 1)[0][17][:6]
        self.assertTrue(np.all(edge_middle_slice == np.asarray([200, 200, 200, 200, 1, 0])))
        self.assertEqual(transforms.ToTensor()(edge_padded_img).size(), (3, 35, 35))

        # Pad 3 to left/right, 2 to top/bottom
        reflect_padded_img = F.pad(img, (3, 2), padding_mode='reflect')
        # First 6 elements of leftmost edge in the middle of the image, values are in order:
        # reflect_pad, reflect_pad, reflect_pad, constant_pad, constant value added to leftmost edge, 0
        reflect_middle_slice = np.asarray(reflect_padded_img).transpose(2, 0, 1)[0][17][:6]
        self.assertTrue(np.all(reflect_middle_slice == np.asarray([0, 0, 1, 200, 1, 0])))
        self.assertEqual(transforms.ToTensor()(reflect_padded_img).size(), (3, 33, 35))

        # Pad 3 to left, 2 to top, 2 to right, 1 to bottom
        symmetric_padded_img = F.pad(img, (3, 2, 2, 1), padding_mode='symmetric')
        # First 6 elements of leftmost edge in the middle of the image, values are in order:
        # sym_pad, sym_pad, sym_pad, constant_pad, constant value added to leftmost edge, 0
        symmetric_middle_slice = np.asarray(symmetric_padded_img).transpose(2, 0, 1)[0][17][:6]
        self.assertTrue(np.all(symmetric_middle_slice == np.asarray([0, 1, 200, 200, 1, 0])))
        self.assertEqual(transforms.ToTensor()(symmetric_padded_img).size(), (3, 32, 34))

    def test_pad_raises_with_invalid_pad_sequence_len(self):
        with self.assertRaises(ValueError):
            transforms.Pad(())

        with self.assertRaises(ValueError):
            transforms.Pad((1, 2, 3))

        with self.assertRaises(ValueError):
            transforms.Pad((1, 2, 3, 4, 5))

    def test_lambda(self):
        trans = transforms.Lambda(lambda x: x.add(10))
        x = torch.randn(10)
        y = trans(x)
        self.assertTrue(y.equal(torch.add(x, 10)))

        trans = transforms.Lambda(lambda x: x.add_(10))
        x = torch.randn(10)
        y = trans(x)
        self.assertTrue(y.equal(x))

        # Checking if Lambda can be printed as string
        trans.__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_apply(self):
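        # RandomApply runs the whole transform list with probability p=0.75; the
        # empirical application rate is checked with a two-sided binomial test.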
        random_state = random.getstate()
        random.seed(42)
        random_apply_transform = transforms.RandomApply(
            [
                transforms.RandomRotation((-45, 45)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
            ], p=0.75
        )
        img = transforms.ToPILImage()(torch.rand(3, 10, 10))
        num_samples = 250
        num_applies = 0
        for _ in range(num_samples):
            out = random_apply_transform(img)
            if out != img:
                num_applies += 1

        p_value = stats.binom_test(num_applies, num_samples, p=0.75)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        # Checking if RandomApply can be printed as string
        random_apply_transform.__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_choice(self):
        random_state = random.getstate()
        random.seed(42)
        random_choice_transform = transforms.RandomChoice(
            [
                transforms.Resize(15),
                transforms.Resize(20),
                transforms.CenterCrop(10)
            ]
        )
        img = transforms.ToPILImage()(torch.rand(3, 25, 25))
        num_samples = 250
        num_resize_15 = 0
        num_resize_20 = 0
        num_crop_10 = 0
        for _ in range(num_samples):
            out = random_choice_transform(img)
            if out.size == (15, 15):
                num_resize_15 += 1
            elif out.size == (20, 20):
                num_resize_20 += 1
            elif out.size == (10, 10):
                num_crop_10 += 1

        p_value = stats.binom_test(num_resize_15, num_samples, p=0.33333)
        self.assertGreater(p_value, 0.0001)
        p_value = stats.binom_test(num_resize_20, num_samples, p=0.33333)
        self.assertGreater(p_value, 0.0001)
        p_value = stats.binom_test(num_crop_10, num_samples, p=0.33333)
        self.assertGreater(p_value, 0.0001)

        random.setstate(random_state)
        # Checking if RandomChoice can be printed as string
        random_choice_transform.__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_order(self):
        random_state = random.getstate()
        random.seed(42)
        random_order_transform = transforms.RandomOrder(
            [
                transforms.Resize(20),
                transforms.CenterCrop(10)
            ]
        )
        img = transforms.ToPILImage()(torch.rand(3, 25, 25))
        num_samples = 250
        num_normal_order = 0
        resize_crop_out = transforms.CenterCrop(10)(transforms.Resize(20)(img))
        for _ in range(num_samples):
            out = random_order_transform(img)
            if out == resize_crop_out:
                num_normal_order += 1

        p_value = stats.binom_test(num_normal_order, num_samples, p=0.5)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        # Checking if RandomOrder can be printed as string
        random_order_transform.__repr__()

    def test_to_tensor(self):
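        # ToTensor converts HWC uint8 input in [0, 255] to a CHW float tensor in
        # [0.0, 1.0]; float ndarrays are only transposed, not rescaled.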
        test_channels = [1, 3, 4]
        height, width = 4, 4
        trans = transforms.ToTensor()

        with self.assertRaises(TypeError):
            trans(np.random.rand(1, height, width).tolist())

        with self.assertRaises(ValueError):
            trans(np.random.rand(height))
        with self.assertRaises(ValueError):
            trans(np.random.rand(1, 1, height, width))

        for channels in test_channels:
            input_data = torch.ByteTensor(channels, height, width).random_(0, 255).float().div_(255)
            img = transforms.ToPILImage()(input_data)
            output = trans(img)
            self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

            ndarray = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
            output = trans(ndarray)
            expected_output = ndarray.transpose((2, 0, 1)) / 255.0
            self.assertTrue(np.allclose(output.numpy(), expected_output))

            ndarray = np.random.rand(height, width, channels).astype(np.float32)
            output = trans(ndarray)
            expected_output = ndarray.transpose((2, 0, 1))
            self.assertTrue(np.allclose(output.numpy(), expected_output))

        # separate test for mode '1' PIL images
        input_data = torch.ByteTensor(1, height, width).bernoulli_()
        img = transforms.ToPILImage()(input_data.mul(255)).convert('1')
        output = trans(img)
        self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

    @unittest.skipIf(accimage is None, 'accimage not available')
    def test_accimage_to_tensor(self):
        trans = transforms.ToTensor()

        expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
        output = trans(accimage.Image(GRACE_HOPPER))

        self.assertEqual(expected_output.size(), output.size())
        self.assertTrue(np.allclose(output.numpy(), expected_output.numpy()))

    def test_pil_to_tensor(self):
        test_channels = [1, 3, 4]
        height, width = 4, 4
        trans = transforms.PILToTensor()

        with self.assertRaises(TypeError):
            trans(np.random.rand(1, height, width).tolist())
        with self.assertRaises(TypeError):
            trans(np.random.rand(1, height, width))

        for channels in test_channels:
            input_data = torch.ByteTensor(channels, height, width).random_(0, 255)
            img = transforms.ToPILImage()(input_data)
            output = trans(img)
            self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

            input_data = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
            img = transforms.ToPILImage()(input_data)
            output = trans(img)
            expected_output = input_data.transpose((2, 0, 1))
            self.assertTrue(np.allclose(output.numpy(), expected_output))

            input_data = torch.as_tensor(np.random.rand(channels, height, width).astype(np.float32))
            img = transforms.ToPILImage()(input_data)  # CHW -> HWC and (* 255).byte()
            output = trans(img)  # HWC -> CHW
            expected_output = (input_data * 255).byte()
            self.assertTrue(np.allclose(output.numpy(), expected_output.numpy()))

        # separate test for mode '1' PIL images
        input_data = torch.ByteTensor(1, height, width).bernoulli_()
        img = transforms.ToPILImage()(input_data.mul(255)).convert('1')
        output = trans(img)
        self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

    @unittest.skipIf(accimage is None, 'accimage not available')
    def test_accimage_pil_to_tensor(self):
        trans = transforms.PILToTensor()

        expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
        output = trans(accimage.Image(GRACE_HOPPER))

        self.assertEqual(expected_output.size(), output.size())
        self.assertTrue(np.allclose(output.numpy(), expected_output.numpy()))

    @unittest.skipIf(accimage is None, 'accimage not available')
    def test_accimage_resize(self):
        trans = transforms.Compose([
            transforms.Resize(256, interpolation=Image.LINEAR),
            transforms.ToTensor(),
        ])

        # Checking if Compose, Resize and ToTensor can be printed as string
        trans.__repr__()

        expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
        output = trans(accimage.Image(GRACE_HOPPER))

        self.assertEqual(expected_output.size(), output.size())
        self.assertLess(np.abs((expected_output - output).mean()), 1e-3)
        self.assertLess((expected_output - output).var(), 1e-5)
        # note the high absolute tolerance
        self.assertTrue(np.allclose(output.numpy(), expected_output.numpy(), atol=5e-2))

    @unittest.skipIf(accimage is None, 'accimage not available')
    def test_accimage_crop(self):
        trans = transforms.Compose([
            transforms.CenterCrop(256),
            transforms.ToTensor(),
        ])

        # Checking if Compose, CenterCrop and ToTensor can be printed as string
        trans.__repr__()

        expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
        output = trans(accimage.Image(GRACE_HOPPER))

        self.assertEqual(expected_output.size(), output.size())
        self.assertTrue(np.allclose(output.numpy(), expected_output.numpy()))

    def test_1_channel_tensor_to_pil_image(self):
        to_tensor = transforms.ToTensor()

        img_data_float = torch.Tensor(1, 4, 4).uniform_()
        img_data_byte = torch.ByteTensor(1, 4, 4).random_(0, 255)
        img_data_short = torch.ShortTensor(1, 4, 4).random_()
        img_data_int = torch.IntTensor(1, 4, 4).random_()

        inputs = [img_data_float, img_data_byte, img_data_short, img_data_int]
        expected_outputs = [img_data_float.mul(255).int().float().div(255).numpy(),
                            img_data_byte.float().div(255.0).numpy(),
                            img_data_short.numpy(),
                            img_data_int.numpy()]
        expected_modes = ['L', 'L', 'I;16', 'I']

        for img_data, expected_output, mode in zip(inputs, expected_outputs, expected_modes):
            for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]:
                img = transform(img_data)
                self.assertEqual(img.mode, mode)
                self.assertTrue(np.allclose(expected_output, to_tensor(img).numpy()))
        # 'F' mode for torch.FloatTensor
        img_F_mode = transforms.ToPILImage(mode='F')(img_data_float)
        self.assertEqual(img_F_mode.mode, 'F')
        self.assertTrue(np.allclose(np.array(Image.fromarray(img_data_float.squeeze(0).numpy(), mode='F')),
                                    np.array(img_F_mode)))

    def test_1_channel_ndarray_to_pil_image(self):
        img_data_float = torch.Tensor(4, 4, 1).uniform_().numpy()
        img_data_byte = torch.ByteTensor(4, 4, 1).random_(0, 255).numpy()
        img_data_short = torch.ShortTensor(4, 4, 1).random_().numpy()
        img_data_int = torch.IntTensor(4, 4, 1).random_().numpy()

        inputs = [img_data_float, img_data_byte, img_data_short, img_data_int]
        expected_modes = ['F', 'L', 'I;16', 'I']
        for img_data, mode in zip(inputs, expected_modes):
            for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]:
                img = transform(img_data)
                self.assertEqual(img.mode, mode)
                self.assertTrue(np.allclose(img_data[:, :, 0], img))

    def test_2_channel_ndarray_to_pil_image(self):
        def verify_img_data(img_data, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'LA')  # default should assume LA
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)
            split = img.split()
            for i in range(2):
                self.assertTrue(np.allclose(img_data[:, :, i], split[i]))

        img_data = torch.ByteTensor(4, 4, 2).random_(0, 255).numpy()
        for mode in [None, 'LA']:
            verify_img_data(img_data, mode)

        transforms.ToPILImage().__repr__()

        # should raise if we try a mode for 4 or 1 or 3 channel images
        for wrong_mode in ['RGBA', 'P', 'RGB']:
            with self.assertRaises(ValueError):
                transforms.ToPILImage(mode=wrong_mode)(img_data)

    def test_2_channel_tensor_to_pil_image(self):
        def verify_img_data(img_data, expected_output, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'LA')  # default should assume LA
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)
            split = img.split()
            for i in range(2):
                self.assertTrue(np.allclose(expected_output[i].numpy(), F.to_tensor(split[i]).numpy()))

        img_data = torch.Tensor(2, 4, 4).uniform_()
        expected_output = img_data.mul(255).int().float().div(255)
        for mode in [None, 'LA']:
            verify_img_data(img_data, expected_output, mode=mode)

        # should raise if we try a mode for 4 or 1 or 3 channel images
        for wrong_mode in ['RGBA', 'P', 'RGB']:
            with self.assertRaises(ValueError):
                transforms.ToPILImage(mode=wrong_mode)(img_data)

    def test_3_channel_tensor_to_pil_image(self):
        def verify_img_data(img_data, expected_output, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'RGB')  # default should assume RGB
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)
            split = img.split()
            for i in range(3):
                self.assertTrue(np.allclose(expected_output[i].numpy(), F.to_tensor(split[i]).numpy()))

        img_data = torch.Tensor(3, 4, 4).uniform_()
        expected_output = img_data.mul(255).int().float().div(255)
        for mode in [None, 'RGB', 'HSV', 'YCbCr']:
            verify_img_data(img_data, expected_output, mode=mode)

        # should raise if we try a mode for 4 or 1 or 2 channel images
        for wrong_mode in ['RGBA', 'P', 'LA']:
            with self.assertRaises(ValueError):
                transforms.ToPILImage(mode=wrong_mode)(img_data)

        with self.assertRaises(ValueError):
            transforms.ToPILImage()(torch.Tensor(1, 3, 4, 4).uniform_())

    def test_3_channel_ndarray_to_pil_image(self):
        def verify_img_data(img_data, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'RGB')  # default should assume RGB
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)
            split = img.split()
            for i in range(3):
                self.assertTrue(np.allclose(img_data[:, :, i], split[i]))

        img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()
        for mode in [None, 'RGB', 'HSV', 'YCbCr']:
            verify_img_data(img_data, mode)

        # Checking if ToPILImage can be printed as string
        transforms.ToPILImage().__repr__()

        # should raise if we try a mode for 4 or 1 or 2 channel images
        for wrong_mode in ['RGBA', 'P', 'LA']:
            with self.assertRaises(ValueError):
                transforms.ToPILImage(mode=wrong_mode)(img_data)

    def test_4_channel_tensor_to_pil_image(self):
        def verify_img_data(img_data, expected_output, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'RGBA')  # default should assume RGBA
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)

            split = img.split()
            for i in range(4):
                self.assertTrue(np.allclose(expected_output[i].numpy(), F.to_tensor(split[i]).numpy()))

        img_data = torch.Tensor(4, 4, 4).uniform_()
        expected_output = img_data.mul(255).int().float().div(255)
        for mode in [None, 'RGBA', 'CMYK', 'RGBX']:
            verify_img_data(img_data, expected_output, mode)

        # should raise if we try a mode for 3 or 1 or 2 channel images
        for wrong_mode in ['RGB', 'P', 'LA']:
            with self.assertRaises(ValueError):
                transforms.ToPILImage(mode=wrong_mode)(img_data)

    def test_4_channel_ndarray_to_pil_image(self):
        def verify_img_data(img_data, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'RGBA')  # default should assume RGBA
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)
            split = img.split()
            for i in range(4):
                self.assertTrue(np.allclose(img_data[:, :, i], split[i]))

        img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy()
        for mode in [None, 'RGBA', 'CMYK', 'RGBX']:
            verify_img_data(img_data, mode)

        # should raise if we try a mode for 3 or 1 or 2 channel images
        for wrong_mode in ['RGB', 'P', 'LA']:
            with self.assertRaises(ValueError):
                transforms.ToPILImage(mode=wrong_mode)(img_data)

    def test_2d_tensor_to_pil_image(self):
        to_tensor = transforms.ToTensor()

        img_data_float = torch.Tensor(4, 4).uniform_()
        img_data_byte = torch.ByteTensor(4, 4).random_(0, 255)
        img_data_short = torch.ShortTensor(4, 4).random_()
        img_data_int = torch.IntTensor(4, 4).random_()

        inputs = [img_data_float, img_data_byte, img_data_short, img_data_int]
        expected_outputs = [img_data_float.mul(255).int().float().div(255).numpy(),
                            img_data_byte.float().div(255.0).numpy(),
                            img_data_short.numpy(),
                            img_data_int.numpy()]
        expected_modes = ['L', 'L', 'I;16', 'I']

        for img_data, expected_output, mode in zip(inputs, expected_outputs, expected_modes):
            for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]:
                img = transform(img_data)
                self.assertEqual(img.mode, mode)
                self.assertTrue(np.allclose(expected_output, to_tensor(img).numpy()))

    def test_2d_ndarray_to_pil_image(self):
        img_data_float = torch.Tensor(4, 4).uniform_().numpy()
        img_data_byte = torch.ByteTensor(4, 4).random_(0, 255).numpy()
        img_data_short = torch.ShortTensor(4, 4).random_().numpy()
        img_data_int = torch.IntTensor(4, 4).random_().numpy()

        inputs = [img_data_float, img_data_byte, img_data_short, img_data_int]
        expected_modes = ['F', 'L', 'I;16', 'I']
        for img_data, mode in zip(inputs, expected_modes):
            for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]:
                img = transform(img_data)
                self.assertEqual(img.mode, mode)
                self.assertTrue(np.allclose(img_data, img))

    def test_tensor_bad_types_to_pil_image(self):
        with self.assertRaises(ValueError):
            transforms.ToPILImage()(torch.ones(1, 3, 4, 4))

    def test_ndarray_bad_types_to_pil_image(self):
        trans = transforms.ToPILImage()
        for dtype in [np.int64, np.uint16, np.uint32, np.float64]:
            with self.assertRaises(TypeError):
                trans(np.ones([4, 4, 1], dtype))

        with self.assertRaises(ValueError):
            transforms.ToPILImage()(np.ones([1, 4, 4, 3]))

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_vertical_flip(self):
        random_state = random.getstate()
        random.seed(42)
        img = transforms.ToPILImage()(torch.rand(3, 10, 10))
        vimg = img.transpose(Image.FLIP_TOP_BOTTOM)

        num_samples = 250
        num_vertical = 0
        for _ in range(num_samples):
            out = transforms.RandomVerticalFlip()(img)
            if out == vimg:
                num_vertical += 1

        p_value = stats.binom_test(num_vertical, num_samples, p=0.5)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        num_samples = 250
        num_vertical = 0
        for _ in range(num_samples):
            out = transforms.RandomVerticalFlip(p=0.7)(img)
            if out == vimg:
                num_vertical += 1

        p_value = stats.binom_test(num_vertical, num_samples, p=0.7)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        # Checking if RandomVerticalFlip can be printed as string
        transforms.RandomVerticalFlip().__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_horizontal_flip(self):
        random_state = random.getstate()
        random.seed(42)
        img = transforms.ToPILImage()(torch.rand(3, 10, 10))
        himg = img.transpose(Image.FLIP_LEFT_RIGHT)

        num_samples = 250
        num_horizontal = 0
        for _ in range(num_samples):
            out = transforms.RandomHorizontalFlip()(img)
            if out == himg:
                num_horizontal += 1

        p_value = stats.binom_test(num_horizontal, num_samples, p=0.5)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        num_samples = 250
        num_horizontal = 0
        for _ in range(num_samples):
            out = transforms.RandomHorizontalFlip(p=0.7)(img)
            if out == himg:
                num_horizontal += 1

        p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        # Checking if RandomHorizontalFlip can be printed as string
        transforms.RandomHorizontalFlip().__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats is not available')
    def test_normalize(self):
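        # Normalizing each channel by its own mean/std should leave data that
        # looks standard normal; verify with a Kolmogorov-Smirnov test vs N(0, 1).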
        def samples_from_standard_normal(tensor):
            p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue
            return p_value > 0.0001

        random_state = random.getstate()
        random.seed(42)
        for channels in [1, 3]:
            img = torch.rand(channels, 10, 10)
            mean = [img[c].mean() for c in range(channels)]
            std = [img[c].std() for c in range(channels)]
            normalized = transforms.Normalize(mean, std)(img)
            self.assertTrue(samples_from_standard_normal(normalized))
        random.setstate(random_state)

        # Checking if Normalize can be printed as string
        transforms.Normalize(mean, std).__repr__()

        # Checking the optional in-place behaviour
        tensor = torch.rand((1, 16, 16))
        tensor_inplace = transforms.Normalize((0.5,), (0.5,), inplace=True)(tensor)
        self.assertTrue(torch.equal(tensor, tensor_inplace))

    def test_normalize_different_dtype(self):
        for dtype1 in [torch.float32, torch.float64]:
            img = torch.rand(3, 10, 10, dtype=dtype1)
            for dtype2 in [torch.int64, torch.float32, torch.float64]:
                mean = torch.tensor([1, 2, 3], dtype=dtype2)
                std = torch.tensor([1, 2, 1], dtype=dtype2)
                # checks that it doesn't crash
                transforms.functional.normalize(img, mean, std)

    def test_normalize_3d_tensor(self):
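        # Per-channel mean/std must broadcast: passing them as (C, 1, 1) views or
        # as fully expanded (C, H, W) tensors should produce identical results.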
        torch.manual_seed(28)
        n_channels = 3
        img_size = 10
        mean = torch.rand(n_channels)
        std = torch.rand(n_channels)
        img = torch.rand(n_channels, img_size, img_size)
        target = F.normalize(img, mean, std).numpy()

        mean_unsqueezed = mean.view(-1, 1, 1)
        std_unsqueezed = std.view(-1, 1, 1)
        result1 = F.normalize(img, mean_unsqueezed, std_unsqueezed)
        result2 = F.normalize(img,
                              mean_unsqueezed.repeat(1, img_size, img_size),
                              std_unsqueezed.repeat(1, img_size, img_size))
        assert_array_almost_equal(target, result1.numpy())
        assert_array_almost_equal(target, result2.numpy())

    def test_adjust_brightness(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')

        # test 0
        y_pil = F.adjust_brightness(x_pil, 1)
        y_np = np.array(y_pil)
        self.assertTrue(np.allclose(y_np, x_np))

        # test 1
        y_pil = F.adjust_brightness(x_pil, 0.5)
        y_np = np.array(y_pil)
        y_ans = [0, 2, 6, 27, 67, 113, 18, 4, 117, 45, 127, 0]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 2
        y_pil = F.adjust_brightness(x_pil, 2)
        y_np = np.array(y_pil)
        y_ans = [0, 10, 26, 108, 255, 255, 74, 16, 255, 180, 255, 2]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

    def test_adjust_contrast(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')

        # test 0
        y_pil = F.adjust_contrast(x_pil, 1)
        y_np = np.array(y_pil)
        self.assertTrue(np.allclose(y_np, x_np))

        # test 1
        y_pil = F.adjust_contrast(x_pil, 0.5)
        y_np = np.array(y_pil)
        y_ans = [43, 45, 49, 70, 110, 156, 61, 47, 160, 88, 170, 43]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 2
        y_pil = F.adjust_contrast(x_pil, 2)
        y_np = np.array(y_pil)
        y_ans = [0, 0, 0, 22, 184, 255, 0, 0, 255, 94, 255, 0]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

    @unittest.skipIf(Image.__version__ >= '7', "Temporarily disabled")
    def test_adjust_saturation(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')

        # test 0
        y_pil = F.adjust_saturation(x_pil, 1)
        y_np = np.array(y_pil)
        self.assertTrue(np.allclose(y_np, x_np))

        # test 1
        y_pil = F.adjust_saturation(x_pil, 0.5)
        y_np = np.array(y_pil)
        y_ans = [2, 4, 8, 87, 128, 173, 39, 25, 138, 133, 215, 88]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 2
        y_pil = F.adjust_saturation(x_pil, 2)
        y_np = np.array(y_pil)
        y_ans = [0, 6, 22, 0, 149, 255, 32, 0, 255, 4, 255, 0]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

    def test_adjust_hue(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')

        # hue_factor must lie in [-0.5, 0.5]
        with self.assertRaises(ValueError):
            F.adjust_hue(x_pil, -0.7)
        with self.assertRaises(ValueError):
            F.adjust_hue(x_pil, 1)

        # test 0: almost same as x_data but not exact.
        # probably because hsv <-> rgb floating point ops
        y_pil = F.adjust_hue(x_pil, 0)
        y_np = np.array(y_pil)
        y_ans = [0, 5, 13, 54, 139, 226, 35, 8, 234, 91, 255, 1]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 1
        y_pil = F.adjust_hue(x_pil, 0.25)
        y_np = np.array(y_pil)
        y_ans = [13, 0, 12, 224, 54, 226, 234, 8, 99, 1, 222, 255]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 2
        y_pil = F.adjust_hue(x_pil, -0.25)
        y_np = np.array(y_pil)
        y_ans = [0, 13, 2, 54, 226, 58, 8, 234, 152, 255, 43, 1]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

    def test_adjust_gamma(self):
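        # Gamma correction maps each channel as out = 255 * (in / 255) ** gamma,
        # so gamma < 1 brightens the image and gamma > 1 darkens it.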
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')

        # test 0
        y_pil = F.adjust_gamma(x_pil, 1)
        y_np = np.array(y_pil)
        self.assertTrue(np.allclose(y_np, x_np))

        # test 1
        y_pil = F.adjust_gamma(x_pil, 0.5)
        y_np = np.array(y_pil)
        y_ans = [0, 35, 57, 117, 185, 240, 97, 45, 244, 151, 255, 15]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 2
        y_pil = F.adjust_gamma(x_pil, 2)
        y_np = np.array(y_pil)
        y_ans = [0, 0, 0, 11, 71, 200, 5, 0, 214, 31, 255, 0]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

    def test_adjusts_L_mode(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_rgb = Image.fromarray(x_np, mode='RGB')

        x_l = x_rgb.convert('L')
        self.assertEqual(F.adjust_brightness(x_l, 2).mode, 'L')
        self.assertEqual(F.adjust_saturation(x_l, 2).mode, 'L')
        self.assertEqual(F.adjust_contrast(x_l, 2).mode, 'L')
        self.assertEqual(F.adjust_hue(x_l, 0.4).mode, 'L')
        self.assertEqual(F.adjust_gamma(x_l, 0.5).mode, 'L')

    def test_color_jitter(self):
        color_jitter = transforms.ColorJitter(2, 2, 2, 0.1)

        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')
        x_pil_2 = x_pil.convert('L')

        for _ in range(10):
            y_pil = color_jitter(x_pil)
            self.assertEqual(y_pil.mode, x_pil.mode)

            y_pil_2 = color_jitter(x_pil_2)
            self.assertEqual(y_pil_2.mode, x_pil_2.mode)

        # Checking if ColorJitter can be printed as string
        color_jitter.__repr__()

    def test_linear_transformation(self):
        num_samples = 1000
        x = torch.randn(num_samples, 3, 10, 10)
        flat_x = x.view(x.size(0), x.size(1) * x.size(2) * x.size(3))
        # compute principal components
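        # sigma = E[x x^T]; with the SVD sigma = u diag(s) u^T, the ZCA whitening
        # matrix applied below is u diag(1 / sqrt(s + eps)) u^T.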
        sigma = torch.mm(flat_x.t(), flat_x) / flat_x.size(0)
        u, s, _ = np.linalg.svd(sigma.numpy())
        zca_epsilon = 1e-10  # avoid division by 0
        d = torch.Tensor(np.diag(1. / np.sqrt(s + zca_epsilon)))
        u = torch.Tensor(u)
        principal_components = torch.mm(torch.mm(u, d), u.t())
        mean_vector = (torch.sum(flat_x, dim=0) / flat_x.size(0))
        # initialize whitening matrix
        whitening = transforms.LinearTransformation(principal_components, mean_vector)
        # estimate covariance and mean using weak law of large number
        num_features = flat_x.size(1)
        cov = 0.0
        mean = 0.0
        for i in x:
            xwhite = whitening(i)
            xwhite = xwhite.view(1, -1).numpy()
            cov += np.dot(xwhite, xwhite.T) / num_features
            mean += np.sum(xwhite) / num_features
        # if rtol for std = 1e-3 then rtol for cov = 2e-3 as std**2 = cov
        self.assertTrue(np.allclose(cov / num_samples, np.identity(1), rtol=2e-3),
                        "cov not close to 1")
        self.assertTrue(np.allclose(mean / num_samples, 0, rtol=1e-3),
                        "mean not close to 0")

        # Checking if LinearTransformation can be printed as string
        whitening.__repr__()

    def test_rotate(self):
        x = np.zeros((100, 100, 3), dtype=np.uint8)
        x[40, 40] = [255, 255, 255]

        with self.assertRaises(TypeError):
            F.rotate(x, 10)

        img = F.to_pil_image(x)

        result = F.rotate(img, 45)
        self.assertEqual(result.size, (100, 100))
        r, c, ch = np.where(result)
        self.assertTrue(all(x in r for x in [49, 50]))
        self.assertTrue(all(x in c for x in [36]))
        self.assertTrue(all(x in ch for x in [0, 1, 2]))

        result = F.rotate(img, 45, expand=True)
        self.assertEqual(result.size, (142, 142))
        r, c, ch = np.where(result)
        self.assertTrue(all(x in r for x in [70, 71]))
        self.assertTrue(all(x in c for x in [57]))
        self.assertTrue(all(x in ch for x in [0, 1, 2]))

        result = F.rotate(img, 45, center=(40, 40))
        self.assertEqual(result.size, (100, 100))
        r, c, ch = np.where(result)
        self.assertTrue(all(x in r for x in [40]))
        self.assertTrue(all(x in c for x in [40]))
        self.assertTrue(all(x in ch for x in [0, 1, 2]))

        result_a = F.rotate(img, 90)
        result_b = F.rotate(img, -270)

        self.assertTrue(np.all(np.array(result_a) == np.array(result_b)))

    def test_rotate_fill(self):
        img = F.to_pil_image(np.ones((100, 100, 3), dtype=np.uint8) * 255, "RGB")

        modes = ("L", "RGB", "F")
        nums_bands = [len(mode) for mode in modes]
        fill = 127

        for mode, num_bands in zip(modes, nums_bands):
            img_conv = img.convert(mode)
            img_rot = F.rotate(img_conv, 45.0, fill=fill)
            pixel = img_rot.getpixel((0, 0))

            if not isinstance(pixel, tuple):
                pixel = (pixel,)
            self.assertTupleEqual(pixel, tuple([fill] * num_bands))

            for wrong_num_bands in set(nums_bands) - {num_bands}:
                with self.assertRaises(ValueError):
                    F.rotate(img_conv, 45.0, fill=tuple([fill] * wrong_num_bands))

    def test_affine(self):
        input_img = np.zeros((40, 40, 3), dtype=np.uint8)
        pts = []
        cnt = [20, 20]
        for pt in [(16, 16), (20, 16), (20, 20)]:
            for i in range(-5, 5):
                for j in range(-5, 5):
                    input_img[pt[0] + i, pt[1] + j, :] = [255, 155, 55]
                    pts.append((pt[0] + i, pt[1] + j))
        pts = list(set(pts))

        with self.assertRaises(TypeError):
            F.affine(input_img, 10)

        pil_img = F.to_pil_image(input_img)

        def _to_3x3_inv(inv_result_matrix):
            result_matrix = np.zeros((3, 3))
            result_matrix[:2, :] = np.array(inv_result_matrix).reshape((2, 3))
            result_matrix[2, 2] = 1
            return np.linalg.inv(result_matrix)

        def _test_transformation(a, t, s, sh):
            a_rad = math.radians(a)
            sh_rad = [math.radians(sh_) for sh_ in sh]
            cx, cy = cnt
            tx, ty = t
            sx, sy = sh_rad
            rot = a_rad

            # 1) Check transformation matrix:
            C = np.array([[1, 0, cx],
                          [0, 1, cy],
                          [0, 0, 1]])
            T = np.array([[1, 0, tx],
                          [0, 1, ty],
                          [0, 0, 1]])
            Cinv = np.linalg.inv(C)

            RS = np.array(
                [[s * math.cos(rot), -s * math.sin(rot), 0],
                 [s * math.sin(rot), s * math.cos(rot), 0],
                 [0, 0, 1]])

            SHx = np.array([[1, -math.tan(sx), 0],
                            [0, 1, 0],
                            [0, 0, 1]])

            SHy = np.array([[1, 0, 0],
                            [-math.tan(sy), 1, 0],
                            [0, 0, 1]])

            RSS = np.matmul(RS, np.matmul(SHy, SHx))

            true_matrix = np.matmul(T, np.matmul(C, np.matmul(RSS, Cinv)))
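            # The composition above assumes the documented convention: pixels
            # are re-centered (Cinv), rotated/scaled/sheared
            # (RSS = RS * SHy * SHx), moved back (C), then translated (T):
            #     M = T * C * RSS * Cinv
            # _get_inverse_affine_matrix returns the inverse of M (PIL maps
            # output coordinates back to input), which is why _to_3x3_inv
            # re-inverts the helper's result before the comparison below.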

            result_matrix = _to_3x3_inv(F._get_inverse_affine_matrix(center=cnt, angle=a,
                                                                     translate=t, scale=s, shear=sh))
            self.assertLess(np.sum(np.abs(true_matrix - result_matrix)), 1e-10)
            # 2) Perform inverse mapping:
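            # Hand-rolled nearest-neighbour warp: map every output pixel
            # (x, y) through the inverse matrix to its source location in the
            # input image and copy that pixel when it falls inside the
            # bounds, mirroring the output-to-input sampling PIL performs.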
            true_result = np.zeros((40, 40, 3), dtype=np.uint8)
            inv_true_matrix = np.linalg.inv(true_matrix)
            for y in range(true_result.shape[0]):
                for x in range(true_result.shape[1]):
                    res = np.dot(inv_true_matrix, [x, y, 1])
                    _x = int(res[0] + 0.5)
                    _y = int(res[1] + 0.5)
                    if 0 <= _x < input_img.shape[1] and 0 <= _y < input_img.shape[0]:
                        true_result[y, x, :] = input_img[_y, _x, :]

            result = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh)
            self.assertEqual(result.size, pil_img.size)
            # Compute number of different pixels:
            np_result = np.array(result)
            n_diff_pixels = np.sum(np_result != true_result) / 3
            # Accept 3 wrong pixels
            self.assertLess(n_diff_pixels, 3,
                            "a={}, t={}, s={}, sh={}\n".format(a, t, s, sh) +
                            "n diff pixels={}\n".format(np.sum(np.array(result)[:, :, 0] != true_result[:, :, 0])))

        # Test rotation
        a = 45
        _test_transformation(a=a, t=(0, 0), s=1.0, sh=(0.0, 0.0))

        # Test translation
        t = [10, 15]
        _test_transformation(a=0.0, t=t, s=1.0, sh=(0.0, 0.0))

        # Test scale
        s = 1.2
        _test_transformation(a=0.0, t=(0.0, 0.0), s=s, sh=(0.0, 0.0))

        # Test shear
        sh = [45.0, 25.0]
        _test_transformation(a=0.0, t=(0.0, 0.0), s=1.0, sh=sh)

        # Test rotation, scale, translation, shear
        for a in range(-90, 90, 25):
            for t1 in range(-10, 10, 5):
                for s in [0.75, 0.98, 1.0, 1.1, 1.2]:
                    for sh in range(-15, 15, 5):
                        _test_transformation(a=a, t=(t1, t1), s=s, sh=(sh, sh))

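    def test_affine_translation_inverse_sketch(self):
        # Illustrative sketch, not part of the original suite. A pure
        # translation has a trivial inverse -- translating back by the same
        # amount -- and _get_inverse_affine_matrix returns the inverse 2x3
        # matrix in row-major order, so for translate=(tx, ty) we expect
        # [1, 0, -tx, 0, 1, -ty].
        tx, ty = 10.0, -5.0
        inv = F._get_inverse_affine_matrix(center=(0, 0), angle=0.0,
                                           translate=(tx, ty), scale=1.0,
                                           shear=(0.0, 0.0))
        assert_array_almost_equal(np.array(inv), [1, 0, -tx, 0, 1, -ty])
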
    def test_random_rotation(self):

        # Each invalid construction must raise on its own; a single
        # assertRaises block would stop at the first raise and leave the
        # remaining cases untested.
        for degrees in [-0.7, [-0.7], [-0.7, 0, 0.7]]:
            with self.assertRaises(ValueError):
                transforms.RandomRotation(degrees)

        t = transforms.RandomRotation(10)
        angle = t.get_params(t.degrees)
        self.assertTrue(-10 <= angle <= 10)

        t = transforms.RandomRotation((-10, 10))
        angle = t.get_params(t.degrees)
        self.assertTrue(-10 <= angle <= 10)

        # Checking if RandomRotation can be printed as string
        t.__repr__()

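    def test_random_rotation_degenerate_range_sketch(self):
        # Illustrative sketch, not part of the original suite, assuming
        # get_params draws uniformly from [min, max] as its use above
        # suggests: a degenerate range pins the sampled angle exactly.
        t = transforms.RandomRotation((5, 5))
        self.assertEqual(t.get_params(t.degrees), 5)
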
    def test_random_affine(self):

        # Each invalid construction below must raise on its own; grouping
        # them under a single assertRaises block would stop at the first
        # raise and leave the rest untested. Some length/type checks may be
        # implemented as bare asserts depending on the torchvision version,
        # so AssertionError is accepted alongside ValueError.
        for kwargs in [
            dict(degrees=-0.7),
            dict(degrees=[-0.7]),
            dict(degrees=[-0.7, 0, 0.7]),
            dict(degrees=[-90, 90], translate=2.0),
            dict(degrees=[-90, 90], translate=[-1.0, 1.0]),
            dict(degrees=[-90, 90], translate=[-1.0, 0.0, 1.0]),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[0.0]),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[-1.0, 1.0]),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[0.5, -0.5]),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[0.5, 3.0, -0.5]),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=-7),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10]),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 0, 10]),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 0, 10, 0, 10]),
        ]:
            with self.assertRaises((ValueError, AssertionError)):
                transforms.RandomAffine(**kwargs)

        x = np.zeros((100, 100, 3), dtype=np.uint8)
        img = F.to_pil_image(x)

        t = transforms.RandomAffine(10, translate=[0.5, 0.3], scale=[0.7, 1.3], shear=[-10, 10, 20, 40])
        for _ in range(100):
            angle, translations, scale, shear = t.get_params(t.degrees, t.translate, t.scale, t.shear,
                                                             img_size=img.size)
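            # translate was given as a fraction of the image size, so the
            # sampled pixel offsets are bounded by 0.5 * width horizontally
            # and 0.3 * height vertically.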
            self.assertTrue(-10 <= angle <= 10)
            self.assertTrue(-img.size[0] * 0.5 <= translations[0] <= img.size[0] * 0.5,
                            "{} vs {}".format(translations[0], img.size[0] * 0.5))
            self.assertTrue(-img.size[1] * 0.3 <= translations[1] <= img.size[1] * 0.3,
                            "{} vs {}".format(translations[1], img.size[1] * 0.3))
            self.assertTrue(0.7 <= scale <= 1.3)
            self.assertTrue(-10 <= shear[0] <= 10)
            self.assertTrue(20 <= shear[1] <= 40)

        # Checking if RandomAffine can be printed as string
        t.__repr__()

        t = transforms.RandomAffine(10, resample=Image.BILINEAR)
        self.assertIn("Image.BILINEAR", t.__repr__())

    def test_to_grayscale(self):
        """Unit tests for grayscale transform"""

        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')
        x_pil_2 = x_pil.convert('L')
        gray_np = np.array(x_pil_2)

        # Test Set: Grayscale an image with desired number of output channels
        # Case 1: RGB -> 1 channel grayscale
        trans1 = transforms.Grayscale(num_output_channels=1)
        gray_pil_1 = trans1(x_pil)
        gray_np_1 = np.array(gray_pil_1)
        self.assertEqual(gray_pil_1.mode, 'L', 'mode should be L')
        self.assertEqual(gray_np_1.shape, tuple(x_shape[0:2]), 'should be 1 channel')
        np.testing.assert_equal(gray_np, gray_np_1)

        # Case 2: RGB -> 3 channel grayscale
        trans2 = transforms.Grayscale(num_output_channels=3)
        gray_pil_2 = trans2(x_pil)
        gray_np_2 = np.array(gray_pil_2)
        self.assertEqual(gray_pil_2.mode, 'RGB', 'mode should be RGB')
        self.assertEqual(gray_np_2.shape, tuple(x_shape), 'should be 3 channel')
        np.testing.assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
        np.testing.assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
        np.testing.assert_equal(gray_np, gray_np_2[:, :, 0])

        # Case 3: 1 channel grayscale -> 1 channel grayscale
        trans3 = transforms.Grayscale(num_output_channels=1)
        gray_pil_3 = trans3(x_pil_2)
        gray_np_3 = np.array(gray_pil_3)
        self.assertEqual(gray_pil_3.mode, 'L', 'mode should be L')
        self.assertEqual(gray_np_3.shape, tuple(x_shape[0:2]), 'should be 1 channel')
        np.testing.assert_equal(gray_np, gray_np_3)

        # Case 4: 1 channel grayscale -> 3 channel grayscale
        trans4 = transforms.Grayscale(num_output_channels=3)
        gray_pil_4 = trans4(x_pil_2)
        gray_np_4 = np.array(gray_pil_4)
        self.assertEqual(gray_pil_4.mode, 'RGB', 'mode should be RGB')
        self.assertEqual(gray_np_4.shape, tuple(x_shape), 'should be 3 channel')
        np.testing.assert_equal(gray_np_4[:, :, 0], gray_np_4[:, :, 1])
        np.testing.assert_equal(gray_np_4[:, :, 1], gray_np_4[:, :, 2])
        np.testing.assert_equal(gray_np, gray_np_4[:, :, 0])

        # Checking if Grayscale can be printed as string
        trans4.__repr__()

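    def test_grayscale_luma_sketch(self):
        # Illustrative sketch, not part of the original suite. PIL's
        # RGB -> L conversion follows the ITU-R 601-2 luma transform,
        #     L = R * 299/1000 + G * 587/1000 + B * 114/1000
        # (PIL uses a fixed-point variant; for these sample values the two
        # agree exactly), so the grayscale values checked above can also be
        # computed by hand.
        x_np = np.array([0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1],
                        dtype=np.uint8).reshape(2, 2, 3)
        rgb = x_np.astype(np.int64)
        expected = (rgb[..., 0] * 299 + rgb[..., 1] * 587 + rgb[..., 2] * 114) // 1000
        gray_np = np.array(Image.fromarray(x_np, mode='RGB').convert('L'))
        np.testing.assert_equal(gray_np, expected.astype(np.uint8))
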
    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_grayscale(self):
        """Unit tests for random grayscale transform"""

        # Test Set 1: RGB -> 3 channel grayscale
        random_state = random.getstate()
        random.seed(42)
        x_shape = [2, 2, 3]
        x_np = np.random.randint(0, 256, x_shape, np.uint8)
        x_pil = Image.fromarray(x_np, mode='RGB')
        x_pil_2 = x_pil.convert('L')
        gray_np = np.array(x_pil_2)

        num_samples = 250
        num_gray = 0
        for _ in range(num_samples):
            gray_pil_2 = transforms.RandomGrayscale(p=0.5)(x_pil)
            gray_np_2 = np.array(gray_pil_2)
            if np.array_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) and \
                    np.array_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) and \
                    np.array_equal(gray_np, gray_np_2[:, :, 0]):
                num_gray = num_gray + 1

        p_value = stats.binom_test(num_gray, num_samples, p=0.5)
        random.setstate(random_state)
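        # binom_test returns the two-sided probability of a count at least
        # this extreme under a fair p=0.5 coin, so the assertion only fires
        # when the observed grayscale rate is wildly inconsistent with 0.5.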
        self.assertGreater(p_value, 0.0001)

        # Test Set 2: grayscale -> 1 channel grayscale
        random_state = random.getstate()
        random.seed(42)
        x_shape = [2, 2, 3]
        x_np = np.random.randint(0, 256, x_shape, np.uint8)
        x_pil = Image.fromarray(x_np, mode='RGB')
        x_pil_2 = x_pil.convert('L')
        gray_np = np.array(x_pil_2)

        num_samples = 250
        num_gray = 0
        for _ in range(num_samples):
            gray_pil_3 = transforms.RandomGrayscale(p=0.5)(x_pil_2)
            gray_np_3 = np.array(gray_pil_3)
            if np.array_equal(gray_np, gray_np_3):
                num_gray = num_gray + 1

        p_value = stats.binom_test(num_gray, num_samples, p=1.0)  # Note: grayscale is always unchanged
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        # Test set 3: Explicit tests
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')
        x_pil_2 = x_pil.convert('L')
        gray_np = np.array(x_pil_2)

        # Case 3a: RGB -> 3 channel grayscale (grayscaled)
        trans2 = transforms.RandomGrayscale(p=1.0)
        gray_pil_2 = trans2(x_pil)
        gray_np_2 = np.array(gray_pil_2)
        self.assertEqual(gray_pil_2.mode, 'RGB', 'mode should be RGB')
        self.assertEqual(gray_np_2.shape, tuple(x_shape), 'should be 3 channel')
        np.testing.assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
        np.testing.assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
        np.testing.assert_equal(gray_np, gray_np_2[:, :, 0])

        # Case 3b: RGB -> 3 channel grayscale (unchanged)
        trans2 = transforms.RandomGrayscale(p=0.0)
        gray_pil_2 = trans2(x_pil)
        gray_np_2 = np.array(gray_pil_2)
        self.assertEqual(gray_pil_2.mode, 'RGB', 'mode should be RGB')
        self.assertEqual(gray_np_2.shape, tuple(x_shape), 'should be 3 channel')
        np.testing.assert_equal(x_np, gray_np_2)

        # Case 3c: 1 channel grayscale -> 1 channel grayscale (grayscaled)
        trans3 = transforms.RandomGrayscale(p=1.0)
        gray_pil_3 = trans3(x_pil_2)
        gray_np_3 = np.array(gray_pil_3)
        self.assertEqual(gray_pil_3.mode, 'L', 'mode should be L')
        self.assertEqual(gray_np_3.shape, tuple(x_shape[0:2]), 'should be 1 channel')
        np.testing.assert_equal(gray_np, gray_np_3)

        # Case 3d: 1 channel grayscale -> 1 channel grayscale (unchanged)
        trans3 = transforms.RandomGrayscale(p=0.0)
        gray_pil_3 = trans3(x_pil_2)
        gray_np_3 = np.array(gray_pil_3)
        self.assertEqual(gray_pil_3.mode, 'L', 'mode should be L')
        self.assertEqual(gray_np_3.shape, tuple(x_shape[0:2]), 'should be 1 channel')
        np.testing.assert_equal(gray_np, gray_np_3)

        # Checking if RandomGrayscale can be printed as string
        trans3.__repr__()

    def test_random_erasing(self):
        """Unit tests for random erasing transform"""

        img = torch.rand([3, 60, 60])

        # Test Set 1: Erasing with a scalar value
        img_re = transforms.RandomErasing(value=0.2)
        i, j, h, w, v = img_re.get_params(img, scale=img_re.scale, ratio=img_re.ratio, value=img_re.value)
        img_output = F.erase(img, i, j, h, w, v)
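        # get_params sampled the erase rectangle: (i, j) is its top-left
        # corner, (h, w) its size, and v the value F.erase wrote into that
        # region of the returned tensor.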
        self.assertEqual(img_output.size(0), 3)

        # Test Set 2: Check if the unerased region is preserved
        orig_unerased = img.clone()
        orig_unerased[:, i:i + h, j:j + w] = 0
        output_unerased = img_output.clone()
        output_unerased[:, i:i + h, j:j + w] = 0
        self.assertTrue(torch.equal(orig_unerased, output_unerased))

        # Test Set 3: Erasing with random value
        img_re = transforms.RandomErasing(value='random')(img)
        self.assertEqual(img_re.size(0), 3)

        # Test Set 4: Erasing with tuple value
        img_re = transforms.RandomErasing(value=(0.2, 0.2, 0.2))(img)
        self.assertEqual(img_re.size(0), 3)

        # Test Set 5: Testing the inplace behaviour
        img_re = transforms.RandomErasing(value=0.2, inplace=True)(img)
        self.assertTrue(torch.equal(img_re, img))

        # Test Set 6: Checking when no erased region is selected
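        # With a 1-pixel-wide image and aspect ratios confined to [0.1, 0.2],
        # no valid erase rectangle can be sampled, so RandomErasing falls
        # back to parameters that leave the image unchanged.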
        img = torch.rand([3, 300, 1])
        img_re = transforms.RandomErasing(ratio=(0.1, 0.2), value='random')(img)
        self.assertTrue(torch.equal(img_re, img))


if __name__ == '__main__':
    unittest.main()