import os
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from torch._utils_internal import get_file_path_2
from numpy.testing import assert_array_almost_equal
import unittest
import math
import random
import numpy as np

from PIL import Image
try:
    import accimage
except ImportError:
    accimage = None

try:
    from scipy import stats
except ImportError:
    stats = None

GRACE_HOPPER = get_file_path_2(
    os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')


class Tester(unittest.TestCase):
    def test_crop(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        oheight = random.randint(5, (height - 2) // 2) * 2
        owidth = random.randint(5, (width - 2) // 2) * 2

        img = torch.ones(3, height, width)
        oh1 = (height - oheight) // 2
        ow1 = (width - owidth) // 2
        # Zero out the exact center region; a correct center crop of that
        # size must then sum to zero.
        imgnarrow = img[:, oh1:oh1 + oheight, ow1:ow1 + owidth]
        imgnarrow.fill_(0)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.CenterCrop((oheight, owidth)),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.sum(), 0,
                         "height: {} width: {} oheight: {} owidth: {}".format(height, width, oheight, owidth))
        oheight += 1
        owidth += 1
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.CenterCrop((oheight, owidth)),
            transforms.ToTensor(),
        ])(img)
        sum1 = result.sum()
        self.assertGreater(sum1, 1,
                           "height: {} width: {} oheight: {} owidth: {}".format(height, width, oheight, owidth))
        oheight += 1
        owidth += 1
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.CenterCrop((oheight, owidth)),
            transforms.ToTensor(),
        ])(img)
        sum2 = result.sum()
        self.assertGreater(sum2, 0,
                           "height: {} width: {} oheight: {} owidth: {}".format(height, width, oheight, owidth))
        self.assertGreater(sum2, sum1,
                           "height: {} width: {} oheight: {} owidth: {}".format(height, width, oheight, owidth))

    def test_five_crop(self):
        to_pil_image = transforms.ToPILImage()
        h = random.randint(5, 25)
        w = random.randint(5, 25)
        for single_dim in [True, False]:
            crop_h = random.randint(1, h)
            crop_w = random.randint(1, w)
            if single_dim:
                crop_h = min(crop_h, crop_w)
                crop_w = crop_h
                transform = transforms.FiveCrop(crop_h)
            else:
                transform = transforms.FiveCrop((crop_h, crop_w))

            img = torch.FloatTensor(3, h, w).uniform_()
            results = transform(to_pil_image(img))

            self.assertEqual(len(results), 5)
            for crop in results:
                self.assertEqual(crop.size, (crop_w, crop_h))

            to_pil_image = transforms.ToPILImage()
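            # Build the expected crops by hand: top-left, top-right,
            # bottom-left, bottom-right and center, in the order FiveCrop
            # should return them.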
            tl = to_pil_image(img[:, 0:crop_h, 0:crop_w])
            tr = to_pil_image(img[:, 0:crop_h, w - crop_w:])
            bl = to_pil_image(img[:, h - crop_h:, 0:crop_w])
            br = to_pil_image(img[:, h - crop_h:, w - crop_w:])
            center = transforms.CenterCrop((crop_h, crop_w))(to_pil_image(img))
            expected_output = (tl, tr, bl, br, center)
            self.assertEqual(results, expected_output)

    def test_ten_crop(self):
        to_pil_image = transforms.ToPILImage()
        h = random.randint(5, 25)
        w = random.randint(5, 25)
        for should_vflip in [True, False]:
            for single_dim in [True, False]:
                crop_h = random.randint(1, h)
                crop_w = random.randint(1, w)
                if single_dim:
                    crop_h = min(crop_h, crop_w)
                    crop_w = crop_h
                    transform = transforms.TenCrop(crop_h,
                                                   vertical_flip=should_vflip)
                    five_crop = transforms.FiveCrop(crop_h)
                else:
                    transform = transforms.TenCrop((crop_h, crop_w),
                                                   vertical_flip=should_vflip)
                    five_crop = transforms.FiveCrop((crop_h, crop_w))

                img = to_pil_image(torch.FloatTensor(3, h, w).uniform_())
                results = transform(img)
                expected_output = five_crop(img)
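                # TenCrop should yield these five crops plus the five crops
                # of the flipped image (vertical or horizontal depending on
                # `vertical_flip`), which are appended below.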

                # Checking if FiveCrop and TenCrop can be printed as string
                transform.__repr__()
                five_crop.__repr__()

                if should_vflip:
                    vflipped_img = img.transpose(Image.FLIP_TOP_BOTTOM)
                    expected_output += five_crop(vflipped_img)
                else:
                    hflipped_img = img.transpose(Image.FLIP_LEFT_RIGHT)
                    expected_output += five_crop(hflipped_img)

                self.assertEqual(len(results), 10)
                self.assertEqual(results, expected_output)

    def test_randomresized_params(self):
        height = random.randint(24, 32) * 2
        width = random.randint(24, 32) * 2
        img = torch.ones(3, height, width)
        to_pil_image = transforms.ToPILImage()
        img = to_pil_image(img)
        size = 100
        epsilon = 0.05
        min_scale = 0.25
        for _ in range(10):
            scale_min = max(round(random.random(), 2), min_scale)
            scale_range = (scale_min, scale_min + round(random.random(), 2))
            aspect_min = max(round(random.random(), 2), epsilon)
            aspect_ratio_range = (aspect_min, aspect_min + round(random.random(), 2))
            randresizecrop = transforms.RandomResizedCrop(size, scale_range, aspect_ratio_range)
            i, j, h, w = randresizecrop.get_params(img, scale_range, aspect_ratio_range)
            aspect_ratio_obtained = w / h
            self.assertTrue((min(aspect_ratio_range) - epsilon <= aspect_ratio_obtained and
                             aspect_ratio_obtained <= max(aspect_ratio_range) + epsilon) or
                            aspect_ratio_obtained == 1.0)
            self.assertIsInstance(i, int)
            self.assertIsInstance(j, int)
            self.assertIsInstance(h, int)
            self.assertIsInstance(w, int)
    def test_randomperspective(self):
        for _ in range(10):
            height = random.randint(24, 32) * 2
            width = random.randint(24, 32) * 2
            img = torch.ones(3, height, width)
            to_pil_image = transforms.ToPILImage()
            img = to_pil_image(img)
            perp = transforms.RandomPerspective()
            startpoints, endpoints = perp.get_params(width, height, 0.5)
            tr_img = F.perspective(img, startpoints, endpoints)
            tr_img2 = F.to_tensor(F.perspective(tr_img, endpoints, startpoints))
            tr_img = F.to_tensor(tr_img)
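            # Warping with (startpoints, endpoints) and then back with the
            # swapped pair should roughly undo the distortion, so the
            # round-trip image must be closer to the original than the
            # one-way warp, within a 0.3 slack.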
            self.assertEqual(img.size[0], width)
            self.assertEqual(img.size[1], height)
            self.assertGreater(torch.nn.functional.mse_loss(tr_img, F.to_tensor(img)) + 0.3,
                               torch.nn.functional.mse_loss(tr_img2, F.to_tensor(img)))
    def test_randomperspective_fill(self):
        height = 100
        width = 100
        img = torch.ones(3, height, width)
        to_pil_image = transforms.ToPILImage()
        img = to_pil_image(img)

        modes = ("L", "RGB", "F")
        nums_bands = [len(mode) for mode in modes]
        fill = 127

        for mode, num_bands in zip(modes, nums_bands):
            img_conv = img.convert(mode)
            perspective = transforms.RandomPerspective(p=1, fill=fill)
            tr_img = perspective(img_conv)
            pixel = tr_img.getpixel((0, 0))

            if not isinstance(pixel, tuple):
                pixel = (pixel,)
            self.assertTupleEqual(pixel, tuple([fill] * num_bands))

        for mode, num_bands in zip(modes, nums_bands):
            img_conv = img.convert(mode)
            startpoints, endpoints = transforms.RandomPerspective.get_params(width, height, 0.5)
            tr_img = F.perspective(img_conv, startpoints, endpoints, fill=fill)
            pixel = tr_img.getpixel((0, 0))
            if not isinstance(pixel, tuple):
                pixel = (pixel,)
            self.assertTupleEqual(pixel, tuple([fill] * num_bands))

            for wrong_num_bands in set(nums_bands) - {num_bands}:
                with self.assertRaises(ValueError):
                    F.perspective(img_conv, startpoints, endpoints, fill=tuple([fill] * wrong_num_bands))

    def test_resize(self):
        height = random.randint(24, 32) * 2
        width = random.randint(24, 32) * 2
        osize = random.randint(5, 12) * 2
        img = torch.ones(3, height, width)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(osize),
            transforms.ToTensor(),
        ])(img)
        self.assertIn(osize, result.size())
        if height < width:
            self.assertLessEqual(result.size(1), result.size(2))
        elif width < height:
            self.assertGreaterEqual(result.size(1), result.size(2))

        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize([osize, osize]),
            transforms.ToTensor(),
        ])(img)
        self.assertIn(osize, result.size())
        self.assertEqual(result.size(1), osize)
        self.assertEqual(result.size(2), osize)

        oheight = random.randint(5, 12) * 2
        owidth = random.randint(5, 12) * 2
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((oheight, owidth)),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), oheight)
        self.assertEqual(result.size(2), owidth)

        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize([oheight, owidth]),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), oheight)
        self.assertEqual(result.size(2), owidth)
    def test_random_crop(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        oheight = random.randint(5, (height - 2) // 2) * 2
        owidth = random.randint(5, (width - 2) // 2) * 2
        img = torch.ones(3, height, width)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop((oheight, owidth)),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), oheight)
        self.assertEqual(result.size(2), owidth)
        padding = random.randint(1, 20)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop((oheight, owidth), padding=padding),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), oheight)
        self.assertEqual(result.size(2), owidth)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop((height, width)),
            transforms.ToTensor()
        ])(img)
        self.assertEqual(result.size(1), height)
        self.assertEqual(result.size(2), width)
        self.assertTrue(np.allclose(img.numpy(), result.numpy()))
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop((height + 1, width + 1), pad_if_needed=True),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), height + 1)
        self.assertEqual(result.size(2), width + 1)
    def test_pad(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        img = torch.ones(3, height, width)
        padding = random.randint(1, 20)
        fill = random.randint(1, 50)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Pad(padding, fill=fill),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), height + 2 * padding)
        self.assertEqual(result.size(2), width + 2 * padding)
        # check that all elements in the padded region correspond
        # to the pad value
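        # (ToTensor rescales uint8 pixels to [0, 1], hence the integer fill
        # value is compared against fill / 255)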
        fill_v = fill / 255
        eps = 1e-5
        self.assertTrue((result[:, :padding, :] - fill_v).abs().max() < eps)
        self.assertTrue((result[:, :, :padding] - fill_v).abs().max() < eps)
        self.assertRaises(ValueError, transforms.Pad(padding, fill=(1, 2)),
                          transforms.ToPILImage()(img))
    def test_pad_with_tuple_of_pad_values(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        img = transforms.ToPILImage()(torch.ones(3, height, width))

        padding = tuple([random.randint(1, 20) for _ in range(2)])
        output = transforms.Pad(padding)(img)
        self.assertEqual(output.size, (width + padding[0] * 2, height + padding[1] * 2))

        padding = tuple([random.randint(1, 20) for _ in range(4)])
        output = transforms.Pad(padding)(img)
        self.assertEqual(output.size[0], width + padding[0] + padding[2])
        self.assertEqual(output.size[1], height + padding[1] + padding[3])
        # Checking if Padding can be printed as string
        transforms.Pad(padding).__repr__()

    def test_pad_with_non_constant_padding_modes(self):
        """Unit tests for edge, reflect, symmetric padding"""
        img = torch.zeros(3, 27, 27).byte()
        img[:, :, 0] = 1  # Constant value added to leftmost edge
        img = transforms.ToPILImage()(img)
        img = F.pad(img, 1, (200, 200, 200))

        # Pad 3 to all sides
        edge_padded_img = F.pad(img, 3, padding_mode='edge')
        # First 6 elements of leftmost edge in the middle of the image, values are in order:
        # edge_pad, edge_pad, edge_pad, constant_pad, constant value added to leftmost edge, 0
        edge_middle_slice = np.asarray(edge_padded_img).transpose(2, 0, 1)[0][17][:6]
        self.assertTrue(np.all(edge_middle_slice == np.asarray([200, 200, 200, 200, 1, 0])))
        self.assertEqual(transforms.ToTensor()(edge_padded_img).size(), (3, 35, 35))

        # Pad 3 to left/right, 2 to top/bottom
        reflect_padded_img = F.pad(img, (3, 2), padding_mode='reflect')
        # First 6 elements of leftmost edge in the middle of the image, values are in order:
        # reflect_pad, reflect_pad, reflect_pad, constant_pad, constant value added to leftmost edge, 0
        reflect_middle_slice = np.asarray(reflect_padded_img).transpose(2, 0, 1)[0][17][:6]
        self.assertTrue(np.all(reflect_middle_slice == np.asarray([0, 0, 1, 200, 1, 0])))
        self.assertEqual(transforms.ToTensor()(reflect_padded_img).size(), (3, 33, 35))

        # Pad 3 to left, 2 to top, 2 to right, 1 to bottom
        symmetric_padded_img = F.pad(img, (3, 2, 2, 1), padding_mode='symmetric')
        # First 6 elements of leftmost edge in the middle of the image, values are in order:
        # sym_pad, sym_pad, sym_pad, constant_pad, constant value added to leftmost edge, 0
        symmetric_middle_slice = np.asarray(symmetric_padded_img).transpose(2, 0, 1)[0][17][:6]
        self.assertTrue(np.all(symmetric_middle_slice == np.asarray([0, 1, 200, 200, 1, 0])))
        self.assertEqual(transforms.ToTensor()(symmetric_padded_img).size(), (3, 32, 34))
    def test_pad_raises_with_invalid_pad_sequence_len(self):
        with self.assertRaises(ValueError):
            transforms.Pad(())

        with self.assertRaises(ValueError):
            transforms.Pad((1, 2, 3))

        with self.assertRaises(ValueError):
            transforms.Pad((1, 2, 3, 4, 5))

    def test_lambda(self):
        trans = transforms.Lambda(lambda x: x.add(10))
        x = torch.randn(10)
        y = trans(x)
        self.assertTrue(y.equal(torch.add(x, 10)))

        trans = transforms.Lambda(lambda x: x.add_(10))
        x = torch.randn(10)
        y = trans(x)
        self.assertTrue(y.equal(x))
        # Checking if Lambda can be printed as string
        trans.__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_apply(self):
        random_state = random.getstate()
        random.seed(42)
        random_apply_transform = transforms.RandomApply(
            [
                transforms.RandomRotation((-45, 45)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
            ], p=0.75
        )
        img = transforms.ToPILImage()(torch.rand(3, 10, 10))
        num_samples = 250
        num_applies = 0
        for _ in range(num_samples):
            out = random_apply_transform(img)
            if out != img:
                num_applies += 1

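        # num_applies should follow Binomial(num_samples, 0.75) if
        # RandomApply honors p=0.75, so a two-sided binomial test should not
        # reject that rate at this very loose significance level.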
        p_value = stats.binom_test(num_applies, num_samples, p=0.75)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        # Checking if RandomApply can be printed as string
        random_apply_transform.__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_choice(self):
        random_state = random.getstate()
        random.seed(42)
        random_choice_transform = transforms.RandomChoice(
            [
                transforms.Resize(15),
                transforms.Resize(20),
                transforms.CenterCrop(10)
            ]
        )
        img = transforms.ToPILImage()(torch.rand(3, 25, 25))
        num_samples = 250
        num_resize_15 = 0
        num_resize_20 = 0
        num_crop_10 = 0
        for _ in range(num_samples):
            out = random_choice_transform(img)
            if out.size == (15, 15):
                num_resize_15 += 1
            elif out.size == (20, 20):
                num_resize_20 += 1
            elif out.size == (10, 10):
                num_crop_10 += 1

        p_value = stats.binom_test(num_resize_15, num_samples, p=0.33333)
        self.assertGreater(p_value, 0.0001)
        p_value = stats.binom_test(num_resize_20, num_samples, p=0.33333)
        self.assertGreater(p_value, 0.0001)
        p_value = stats.binom_test(num_crop_10, num_samples, p=0.33333)
        self.assertGreater(p_value, 0.0001)

        random.setstate(random_state)
        # Checking if RandomChoice can be printed as string
        random_choice_transform.__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_order(self):
        random_state = random.getstate()
        random.seed(42)
        random_order_transform = transforms.RandomOrder(
            [
                transforms.Resize(20),
                transforms.CenterCrop(10)
            ]
        )
        img = transforms.ToPILImage()(torch.rand(3, 25, 25))
        num_samples = 250
        num_normal_order = 0
        resize_crop_out = transforms.CenterCrop(10)(transforms.Resize(20)(img))
        for _ in range(num_samples):
            out = random_order_transform(img)
            if out == resize_crop_out:
                num_normal_order += 1

        p_value = stats.binom_test(num_normal_order, num_samples, p=0.5)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        # Checking if RandomOrder can be printed as string
        random_order_transform.__repr__()

    def test_to_tensor(self):
        test_channels = [1, 3, 4]
        height, width = 4, 4
        trans = transforms.ToTensor()

        with self.assertRaises(TypeError):
            trans(np.random.rand(1, height, width).tolist())

        with self.assertRaises(ValueError):
            trans(np.random.rand(height))
        with self.assertRaises(ValueError):
            trans(np.random.rand(1, 1, height, width))

        for channels in test_channels:
            input_data = torch.ByteTensor(channels, height, width).random_(0, 255).float().div_(255)
            img = transforms.ToPILImage()(input_data)
            output = trans(img)
            self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

            ndarray = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
            output = trans(ndarray)
            expected_output = ndarray.transpose((2, 0, 1)) / 255.0
            self.assertTrue(np.allclose(output.numpy(), expected_output))

            ndarray = np.random.rand(height, width, channels).astype(np.float32)
            output = trans(ndarray)
            expected_output = ndarray.transpose((2, 0, 1))
            self.assertTrue(np.allclose(output.numpy(), expected_output))

        # separate test for mode '1' PIL images
        input_data = torch.ByteTensor(1, height, width).bernoulli_()
        img = transforms.ToPILImage()(input_data.mul(255)).convert('1')
        output = trans(img)
        self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

    @unittest.skipIf(accimage is None, 'accimage not available')
    def test_accimage_to_tensor(self):
        trans = transforms.ToTensor()

        expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
        output = trans(accimage.Image(GRACE_HOPPER))

        self.assertEqual(expected_output.size(), output.size())
        self.assertTrue(np.allclose(output.numpy(), expected_output.numpy()))

    def test_pil_to_tensor(self):
        test_channels = [1, 3, 4]
        height, width = 4, 4
        trans = transforms.PILToTensor()

        with self.assertRaises(TypeError):
            trans(np.random.rand(1, height, width).tolist())
        with self.assertRaises(TypeError):
            trans(np.random.rand(1, height, width))

        for channels in test_channels:
            input_data = torch.ByteTensor(channels, height, width).random_(0, 255)
            img = transforms.ToPILImage()(input_data)
            output = trans(img)
            self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

            input_data = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
            img = transforms.ToPILImage()(input_data)
            output = trans(img)
            expected_output = input_data.transpose((2, 0, 1))
            self.assertTrue(np.allclose(output.numpy(), expected_output))

            input_data = torch.as_tensor(np.random.rand(channels, height, width).astype(np.float32))
            img = transforms.ToPILImage()(input_data)  # CHW -> HWC and (* 255).byte()
            output = trans(img)  # HWC -> CHW
            expected_output = (input_data * 255).byte()
            self.assertTrue(np.allclose(output.numpy(), expected_output.numpy()))

        # separate test for mode '1' PIL images
        input_data = torch.ByteTensor(1, height, width).bernoulli_()
        img = transforms.ToPILImage()(input_data.mul(255)).convert('1')
        output = trans(img)
        self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

    @unittest.skipIf(accimage is None, 'accimage not available')
    def test_accimage_pil_to_tensor(self):
        trans = transforms.PILToTensor()

        expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
        output = trans(accimage.Image(GRACE_HOPPER))

        self.assertEqual(expected_output.size(), output.size())
        self.assertTrue(np.allclose(output.numpy(), expected_output.numpy()))

    @unittest.skipIf(accimage is None, 'accimage not available')
    def test_accimage_resize(self):
        trans = transforms.Compose([
            transforms.Resize(256, interpolation=Image.LINEAR),
            transforms.ToTensor(),
        ])

        # Checking if Compose, Resize and ToTensor can be printed as string
        trans.__repr__()

        expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
        output = trans(accimage.Image(GRACE_HOPPER))

        self.assertEqual(expected_output.size(), output.size())
        self.assertLess(np.abs((expected_output - output).mean()), 1e-3)
        self.assertLess((expected_output - output).var(), 1e-5)
        # note the high absolute tolerance
        self.assertTrue(np.allclose(output.numpy(), expected_output.numpy(), atol=5e-2))

    @unittest.skipIf(accimage is None, 'accimage not available')
    def test_accimage_crop(self):
        trans = transforms.Compose([
            transforms.CenterCrop(256),
            transforms.ToTensor(),
        ])

        # Checking if Compose, CenterCrop and ToTensor can be printed as string
        trans.__repr__()

        expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
        output = trans(accimage.Image(GRACE_HOPPER))

        self.assertEqual(expected_output.size(), output.size())
        self.assertTrue(np.allclose(output.numpy(), expected_output.numpy()))
    def test_1_channel_tensor_to_pil_image(self):
        to_tensor = transforms.ToTensor()

        img_data_float = torch.Tensor(1, 4, 4).uniform_()
        img_data_byte = torch.ByteTensor(1, 4, 4).random_(0, 255)
        img_data_short = torch.ShortTensor(1, 4, 4).random_()
        img_data_int = torch.IntTensor(1, 4, 4).random_()

        inputs = [img_data_float, img_data_byte, img_data_short, img_data_int]
        expected_outputs = [img_data_float.mul(255).int().float().div(255).numpy(),
                            img_data_byte.float().div(255.0).numpy(),
                            img_data_short.numpy(),
                            img_data_int.numpy()]
        expected_modes = ['L', 'L', 'I;16', 'I']

        for img_data, expected_output, mode in zip(inputs, expected_outputs, expected_modes):
            for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]:
                img = transform(img_data)
                self.assertEqual(img.mode, mode)
                self.assertTrue(np.allclose(expected_output, to_tensor(img).numpy()))
        # 'F' mode for torch.FloatTensor
        img_F_mode = transforms.ToPILImage(mode='F')(img_data_float)
        self.assertEqual(img_F_mode.mode, 'F')
        self.assertTrue(np.allclose(np.array(Image.fromarray(img_data_float.squeeze(0).numpy(), mode='F')),
                                    np.array(img_F_mode)))

    def test_1_channel_ndarray_to_pil_image(self):
        img_data_float = torch.Tensor(4, 4, 1).uniform_().numpy()
        img_data_byte = torch.ByteTensor(4, 4, 1).random_(0, 255).numpy()
        img_data_short = torch.ShortTensor(4, 4, 1).random_().numpy()
        img_data_int = torch.IntTensor(4, 4, 1).random_().numpy()

        inputs = [img_data_float, img_data_byte, img_data_short, img_data_int]
        expected_modes = ['F', 'L', 'I;16', 'I']
        for img_data, mode in zip(inputs, expected_modes):
            for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]:
                img = transform(img_data)
                self.assertEqual(img.mode, mode)
                self.assertTrue(np.allclose(img_data[:, :, 0], img))
    def test_2_channel_ndarray_to_pil_image(self):
        def verify_img_data(img_data, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'LA')  # default should assume LA
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)
            split = img.split()
            for i in range(2):
                self.assertTrue(np.allclose(img_data[:, :, i], split[i]))

        img_data = torch.ByteTensor(4, 4, 2).random_(0, 255).numpy()
        for mode in [None, 'LA']:
            verify_img_data(img_data, mode)

        transforms.ToPILImage().__repr__()

        # should raise if we try a mode for 4 or 1 or 3 channel images
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='RGBA')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='P')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='RGB')(img_data)

    def test_2_channel_tensor_to_pil_image(self):
        def verify_img_data(img_data, expected_output, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'LA')  # default should assume LA
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)
            split = img.split()
            for i in range(2):
                self.assertTrue(np.allclose(expected_output[i].numpy(), F.to_tensor(split[i]).numpy()))

        img_data = torch.Tensor(2, 4, 4).uniform_()
        expected_output = img_data.mul(255).int().float().div(255)
        for mode in [None, 'LA']:
            verify_img_data(img_data, expected_output, mode=mode)

        # should raise if we try a mode for 4 or 1 or 3 channel images
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='RGBA')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='P')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='RGB')(img_data)

    def test_3_channel_tensor_to_pil_image(self):
        def verify_img_data(img_data, expected_output, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'RGB')  # default should assume RGB
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)
            split = img.split()
            for i in range(3):
                self.assertTrue(np.allclose(expected_output[i].numpy(), F.to_tensor(split[i]).numpy()))
        img_data = torch.Tensor(3, 4, 4).uniform_()
        expected_output = img_data.mul(255).int().float().div(255)
        for mode in [None, 'RGB', 'HSV', 'YCbCr']:
            verify_img_data(img_data, expected_output, mode=mode)
        # should raise if we try a mode for 4 or 1 or 2 channel images
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='RGBA')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='P')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='LA')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage()(torch.Tensor(1, 3, 4, 4).uniform_())

    def test_3_channel_ndarray_to_pil_image(self):
        def verify_img_data(img_data, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'RGB')  # default should assume RGB
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)
            split = img.split()
            for i in range(3):
                self.assertTrue(np.allclose(img_data[:, :, i], split[i]))
        img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()
        for mode in [None, 'RGB', 'HSV', 'YCbCr']:
            verify_img_data(img_data, mode)

        # Checking if ToPILImage can be printed as string
        transforms.ToPILImage().__repr__()

        # should raise if we try a mode for 4 or 1 or 2 channel images
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='RGBA')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='P')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='LA')(img_data)

    def test_4_channel_tensor_to_pil_image(self):
        def verify_img_data(img_data, expected_output, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'RGBA')  # default should assume RGBA
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)

            split = img.split()
            for i in range(4):
                self.assertTrue(np.allclose(expected_output[i].numpy(), F.to_tensor(split[i]).numpy()))
        img_data = torch.Tensor(4, 4, 4).uniform_()
        expected_output = img_data.mul(255).int().float().div(255)
        for mode in [None, 'RGBA', 'CMYK', 'RGBX']:
            verify_img_data(img_data, expected_output, mode)
        # should raise if we try a mode for 3 or 1 or 2 channel images
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='RGB')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='P')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='LA')(img_data)

    def test_4_channel_ndarray_to_pil_image(self):
        def verify_img_data(img_data, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'RGBA')  # default should assume RGBA
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)
            split = img.split()
            for i in range(4):
                self.assertTrue(np.allclose(img_data[:, :, i], split[i]))
        img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy()
        for mode in [None, 'RGBA', 'CMYK', 'RGBX']:
            verify_img_data(img_data, mode)
        # should raise if we try a mode for 3 or 1 or 2 channel images
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='RGB')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='P')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='LA')(img_data)

    def test_2d_tensor_to_pil_image(self):
        to_tensor = transforms.ToTensor()

        img_data_float = torch.Tensor(4, 4).uniform_()
        img_data_byte = torch.ByteTensor(4, 4).random_(0, 255)
        img_data_short = torch.ShortTensor(4, 4).random_()
        img_data_int = torch.IntTensor(4, 4).random_()

        inputs = [img_data_float, img_data_byte, img_data_short, img_data_int]
        expected_outputs = [img_data_float.mul(255).int().float().div(255).numpy(),
                            img_data_byte.float().div(255.0).numpy(),
                            img_data_short.numpy(),
                            img_data_int.numpy()]
        expected_modes = ['L', 'L', 'I;16', 'I']

        for img_data, expected_output, mode in zip(inputs, expected_outputs, expected_modes):
            for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]:
                img = transform(img_data)
                self.assertEqual(img.mode, mode)
                self.assertTrue(np.allclose(expected_output, to_tensor(img).numpy()))

    def test_2d_ndarray_to_pil_image(self):
        img_data_float = torch.Tensor(4, 4).uniform_().numpy()
        img_data_byte = torch.ByteTensor(4, 4).random_(0, 255).numpy()
        img_data_short = torch.ShortTensor(4, 4).random_().numpy()
        img_data_int = torch.IntTensor(4, 4).random_().numpy()

        inputs = [img_data_float, img_data_byte, img_data_short, img_data_int]
        expected_modes = ['F', 'L', 'I;16', 'I']
        for img_data, mode in zip(inputs, expected_modes):
            for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]:
                img = transform(img_data)
                self.assertEqual(img.mode, mode)
                self.assertTrue(np.allclose(img_data, img))

    def test_tensor_bad_types_to_pil_image(self):
        with self.assertRaises(ValueError):
            transforms.ToPILImage()(torch.ones(1, 3, 4, 4))

    def test_ndarray_bad_types_to_pil_image(self):
        trans = transforms.ToPILImage()
        with self.assertRaises(TypeError):
            trans(np.ones([4, 4, 1], np.int64))
        with self.assertRaises(TypeError):
            trans(np.ones([4, 4, 1], np.uint16))
        with self.assertRaises(TypeError):
            trans(np.ones([4, 4, 1], np.uint32))
        with self.assertRaises(TypeError):
            trans(np.ones([4, 4, 1], np.float64))

        with self.assertRaises(ValueError):
            transforms.ToPILImage()(np.ones([1, 4, 4, 3]))

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_vertical_flip(self):
        random_state = random.getstate()
        random.seed(42)
        img = transforms.ToPILImage()(torch.rand(3, 10, 10))
        vimg = img.transpose(Image.FLIP_TOP_BOTTOM)

        num_samples = 250
        num_vertical = 0
        for _ in range(num_samples):
            out = transforms.RandomVerticalFlip()(img)
            if out == vimg:
                num_vertical += 1

        p_value = stats.binom_test(num_vertical, num_samples, p=0.5)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)
        num_samples = 250
        num_vertical = 0
        for _ in range(num_samples):
            out = transforms.RandomVerticalFlip(p=0.7)(img)
            if out == vimg:
                num_vertical += 1

        p_value = stats.binom_test(num_vertical, num_samples, p=0.7)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)
        # Checking if RandomVerticalFlip can be printed as string
        transforms.RandomVerticalFlip().__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_horizontal_flip(self):
        random_state = random.getstate()
        random.seed(42)
        img = transforms.ToPILImage()(torch.rand(3, 10, 10))
        himg = img.transpose(Image.FLIP_LEFT_RIGHT)

        num_samples = 250
        num_horizontal = 0
        for _ in range(num_samples):
            out = transforms.RandomHorizontalFlip()(img)
            if out == himg:
                num_horizontal += 1

        p_value = stats.binom_test(num_horizontal, num_samples, p=0.5)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)
        num_samples = 250
        num_horizontal = 0
        for _ in range(num_samples):
            out = transforms.RandomHorizontalFlip(p=0.7)(img)
            if out == himg:
                num_horizontal += 1

        p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)
        # Checking if RandomHorizontalFlip can be printed as string
        transforms.RandomHorizontalFlip().__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats is not available')
    def test_normalize(self):
        def samples_from_standard_normal(tensor):
            p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue
            return p_value > 0.0001

        random_state = random.getstate()
        random.seed(42)
        for channels in [1, 3]:
            img = torch.rand(channels, 10, 10)
            mean = [img[c].mean() for c in range(channels)]
            std = [img[c].std() for c in range(channels)]
            normalized = transforms.Normalize(mean, std)(img)
            self.assertTrue(samples_from_standard_normal(normalized))
        random.setstate(random_state)

        # Checking if Normalize can be printed as string
        transforms.Normalize(mean, std).__repr__()

        # Checking the optional in-place behaviour
        tensor = torch.rand((1, 16, 16))
        tensor_inplace = transforms.Normalize((0.5,), (0.5,), inplace=True)(tensor)
        self.assertTrue(torch.equal(tensor, tensor_inplace))
    def test_normalize_different_dtype(self):
        for dtype1 in [torch.float32, torch.float64]:
            img = torch.rand(3, 10, 10, dtype=dtype1)
            for dtype2 in [torch.int64, torch.float32, torch.float64]:
                mean = torch.tensor([1, 2, 3], dtype=dtype2)
                std = torch.tensor([1, 2, 1], dtype=dtype2)
                # checks that it doesn't crash
                transforms.functional.normalize(img, mean, std)

    def test_normalize_3d_tensor(self):
        torch.manual_seed(28)
        n_channels = 3
        img_size = 10
        mean = torch.rand(n_channels)
        std = torch.rand(n_channels)
        img = torch.rand(n_channels, img_size, img_size)
        target = F.normalize(img, mean, std).numpy()

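        # normalize should broadcast 1D mean/std over the (C, H, W) image,
        # so the same result is expected whether the statistics are shaped
        # (C,), (C, 1, 1), or fully tiled to (C, H, W).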
        mean_unsqueezed = mean.view(-1, 1, 1)
        std_unsqueezed = std.view(-1, 1, 1)
        result1 = F.normalize(img, mean_unsqueezed, std_unsqueezed)
        result2 = F.normalize(img,
                              mean_unsqueezed.repeat(1, img_size, img_size),
                              std_unsqueezed.repeat(1, img_size, img_size))
        assert_array_almost_equal(target, result1.numpy())
        assert_array_almost_equal(target, result2.numpy())

    def test_adjust_brightness(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')

        # test 0
        y_pil = F.adjust_brightness(x_pil, 1)
        y_np = np.array(y_pil)
        self.assertTrue(np.allclose(y_np, x_np))

        # test 1
        y_pil = F.adjust_brightness(x_pil, 0.5)
        y_np = np.array(y_pil)
        y_ans = [0, 2, 6, 27, 67, 113, 18, 4, 117, 45, 127, 0]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 2
        y_pil = F.adjust_brightness(x_pil, 2)
        y_np = np.array(y_pil)
        y_ans = [0, 10, 26, 108, 255, 255, 74, 16, 255, 180, 255, 2]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

    def test_adjust_contrast(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')

        # test 0
        y_pil = F.adjust_contrast(x_pil, 1)
        y_np = np.array(y_pil)
        self.assertTrue(np.allclose(y_np, x_np))

        # test 1
        y_pil = F.adjust_contrast(x_pil, 0.5)
        y_np = np.array(y_pil)
        y_ans = [43, 45, 49, 70, 110, 156, 61, 47, 160, 88, 170, 43]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 2
        y_pil = F.adjust_contrast(x_pil, 2)
        y_np = np.array(y_pil)
        y_ans = [0, 0, 0, 22, 184, 255, 0, 0, 255, 94, 255, 0]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))
    @unittest.skipIf(Image.__version__ >= '7', "Temporarily disabled")
    def test_adjust_saturation(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')

        # test 0
        y_pil = F.adjust_saturation(x_pil, 1)
        y_np = np.array(y_pil)
        self.assertTrue(np.allclose(y_np, x_np))

        # test 1
        y_pil = F.adjust_saturation(x_pil, 0.5)
        y_np = np.array(y_pil)
        y_ans = [2, 4, 8, 87, 128, 173, 39, 25, 138, 133, 215, 88]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 2
        y_pil = F.adjust_saturation(x_pil, 2)
        y_np = np.array(y_pil)
        y_ans = [0, 6, 22, 0, 149, 255, 32, 0, 255, 4, 255, 0]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

    def test_adjust_hue(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')

        with self.assertRaises(ValueError):
            F.adjust_hue(x_pil, -0.7)
        with self.assertRaises(ValueError):
            F.adjust_hue(x_pil, 1)

        # test 0: almost same as x_data but not exact.
        # probably because hsv <-> rgb floating point ops
        y_pil = F.adjust_hue(x_pil, 0)
        y_np = np.array(y_pil)
        y_ans = [0, 5, 13, 54, 139, 226, 35, 8, 234, 91, 255, 1]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 1
        y_pil = F.adjust_hue(x_pil, 0.25)
        y_np = np.array(y_pil)
        y_ans = [13, 0, 12, 224, 54, 226, 234, 8, 99, 1, 222, 255]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 2
        y_pil = F.adjust_hue(x_pil, -0.25)
        y_np = np.array(y_pil)
        y_ans = [0, 13, 2, 54, 226, 58, 8, 234, 152, 255, 43, 1]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

    def test_adjust_gamma(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')

        # test 0
        y_pil = F.adjust_gamma(x_pil, 1)
        y_np = np.array(y_pil)
        self.assertTrue(np.allclose(y_np, x_np))

        # test 1
        y_pil = F.adjust_gamma(x_pil, 0.5)
        y_np = np.array(y_pil)
        y_ans = [0, 35, 57, 117, 185, 240, 97, 45, 244, 151, 255, 15]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 2
        y_pil = F.adjust_gamma(x_pil, 2)
        y_np = np.array(y_pil)
        y_ans = [0, 0, 0, 11, 71, 200, 5, 0, 214, 31, 255, 0]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

    def test_adjusts_L_mode(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_rgb = Image.fromarray(x_np, mode='RGB')

        x_l = x_rgb.convert('L')
        self.assertEqual(F.adjust_brightness(x_l, 2).mode, 'L')
        self.assertEqual(F.adjust_saturation(x_l, 2).mode, 'L')
        self.assertEqual(F.adjust_contrast(x_l, 2).mode, 'L')
        self.assertEqual(F.adjust_hue(x_l, 0.4).mode, 'L')
        self.assertEqual(F.adjust_gamma(x_l, 0.5).mode, 'L')

    def test_color_jitter(self):
        color_jitter = transforms.ColorJitter(2, 2, 2, 0.1)

        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')
        x_pil_2 = x_pil.convert('L')

        for i in range(10):
            y_pil = color_jitter(x_pil)
            self.assertEqual(y_pil.mode, x_pil.mode)

            y_pil_2 = color_jitter(x_pil_2)
            self.assertEqual(y_pil_2.mode, x_pil_2.mode)
        # Checking if ColorJitter can be printed as string
        color_jitter.__repr__()

    def test_linear_transformation(self):
        num_samples = 1000
        x = torch.randn(num_samples, 3, 10, 10)
        flat_x = x.view(x.size(0), x.size(1) * x.size(2) * x.size(3))
        # compute principal components
        sigma = torch.mm(flat_x.t(), flat_x) / flat_x.size(0)
        u, s, _ = np.linalg.svd(sigma.numpy())
        zca_epsilon = 1e-10  # avoid division by 0
        d = torch.Tensor(np.diag(1. / np.sqrt(s + zca_epsilon)))
        u = torch.Tensor(u)
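        # ZCA whitening matrix W = U diag(1 / sqrt(s + eps)) U^T, built from
        # the SVD of the covariance estimate above; applying it should
        # decorrelate the features.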
        principal_components = torch.mm(torch.mm(u, d), u.t())
        mean_vector = (torch.sum(flat_x, dim=0) / flat_x.size(0))
        # initialize whitening matrix
        whitening = transforms.LinearTransformation(principal_components, mean_vector)
        # estimate covariance and mean using weak law of large number
        num_features = flat_x.size(1)
        cov = 0.0
        mean = 0.0
        for i in x:
            xwhite = whitening(i)
            xwhite = xwhite.view(1, -1).numpy()
            cov += np.dot(xwhite, xwhite.T) / num_features
            mean += np.sum(xwhite) / num_features
        # if rtol for std = 1e-3 then rtol for cov = 2e-3 as std**2 = cov
        self.assertTrue(np.allclose(cov / num_samples, np.identity(1), rtol=2e-3),
                        "cov not close to 1")
        self.assertTrue(np.allclose(mean / num_samples, 0, rtol=1e-3),
                        "mean not close to 0")
        # Checking if LinearTransformation can be printed as string
        whitening.__repr__()

    def test_rotate(self):
        x = np.zeros((100, 100, 3), dtype=np.uint8)
        x[40, 40] = [255, 255, 255]

        with self.assertRaises(TypeError):
            F.rotate(x, 10)

        img = F.to_pil_image(x)

        result = F.rotate(img, 45)
        self.assertEqual(result.size, (100, 100))
        r, c, ch = np.where(result)
        self.assertTrue(all(x in r for x in [49, 50]))
        self.assertTrue(all(x in c for x in [36]))
        self.assertTrue(all(x in ch for x in [0, 1, 2]))

        result = F.rotate(img, 45, expand=True)
        self.assertEqual(result.size, (142, 142))
        r, c, ch = np.where(result)
        self.assertTrue(all(x in r for x in [70, 71]))
        self.assertTrue(all(x in c for x in [57]))
        self.assertTrue(all(x in ch for x in [0, 1, 2]))

        result = F.rotate(img, 45, center=(40, 40))
        self.assertEqual(result.size, (100, 100))
        r, c, ch = np.where(result)
        self.assertTrue(all(x in r for x in [40]))
        self.assertTrue(all(x in c for x in [40]))
        self.assertTrue(all(x in ch for x in [0, 1, 2]))

        result_a = F.rotate(img, 90)
        result_b = F.rotate(img, -270)

        self.assertTrue(np.all(np.array(result_a) == np.array(result_b)))

    def test_rotate_fill(self):
        img = F.to_pil_image(np.ones((100, 100, 3), dtype=np.uint8) * 255, "RGB")

        modes = ("L", "RGB", "F")
        nums_bands = [len(mode) for mode in modes]
        fill = 127

        for mode, num_bands in zip(modes, nums_bands):
            img_conv = img.convert(mode)
            img_rot = F.rotate(img_conv, 45.0, fill=fill)
            pixel = img_rot.getpixel((0, 0))

            if not isinstance(pixel, tuple):
                pixel = (pixel,)
            self.assertTupleEqual(pixel, tuple([fill] * num_bands))

            for wrong_num_bands in set(nums_bands) - {num_bands}:
                with self.assertRaises(ValueError):
                    F.rotate(img_conv, 45.0, fill=tuple([fill] * wrong_num_bands))

    def test_affine(self):
        input_img = np.zeros((40, 40, 3), dtype=np.uint8)
        pts = []
        cnt = [20, 20]
        for pt in [(16, 16), (20, 16), (20, 20)]:
            for i in range(-5, 5):
                for j in range(-5, 5):
                    input_img[pt[0] + i, pt[1] + j, :] = [255, 155, 55]
                    pts.append((pt[0] + i, pt[1] + j))
        pts = list(set(pts))

        with self.assertRaises(TypeError):
            F.affine(input_img, 10)

        pil_img = F.to_pil_image(input_img)

        def _to_3x3_inv(inv_result_matrix):
            result_matrix = np.zeros((3, 3))
            result_matrix[:2, :] = np.array(inv_result_matrix).reshape((2, 3))
            result_matrix[2, 2] = 1
            return np.linalg.inv(result_matrix)

        def _test_transformation(a, t, s, sh):
            a_rad = math.radians(a)
            sh_rad = [math.radians(sh_) for sh_ in sh]  # shear angles in radians
            cx, cy = cnt
            tx, ty = t
            sx, sy = sh_rad
            rot = a_rad

            # 1) Check transformation matrix:
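            # true_matrix composes translation T with rotation/scale RS and the
            # shears SHy, SHx applied about the image center C:
            #   true_matrix = T @ C @ (RS @ SHy @ SHx) @ C^{-1}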
            C = np.array([[1, 0, cx],
                          [0, 1, cy],
                          [0, 0, 1]])
            T = np.array([[1, 0, tx],
                          [0, 1, ty],
                          [0, 0, 1]])
            Cinv = np.linalg.inv(C)

            RS = np.array(
                [[s * math.cos(rot), -s * math.sin(rot), 0],
                 [s * math.sin(rot), s * math.cos(rot), 0],
                 [0, 0, 1]])

            SHx = np.array([[1, -math.tan(sx), 0],
                            [0, 1, 0],
                            [0, 0, 1]])

            SHy = np.array([[1, 0, 0],
                            [-math.tan(sy), 1, 0],
                            [0, 0, 1]])

            RSS = np.matmul(RS, np.matmul(SHy, SHx))

            true_matrix = np.matmul(T, np.matmul(C, np.matmul(RSS, Cinv)))

            result_matrix = _to_3x3_inv(F._get_inverse_affine_matrix(center=cnt, angle=a,
                                                                     translate=t, scale=s, shear=sh))
            self.assertLess(np.sum(np.abs(true_matrix - result_matrix)), 1e-10)
            # 2) Perform inverse mapping:
            true_result = np.zeros((40, 40, 3), dtype=np.uint8)
            inv_true_matrix = np.linalg.inv(true_matrix)
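            # Replicate PIL's inverse mapping: send every output pixel (x, y)
            # through the inverse matrix and copy the nearest input pixel
            # (the +0.5 below rounds to the nearest integer coordinate).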
            for y in range(true_result.shape[0]):
                for x in range(true_result.shape[1]):
                    res = np.dot(inv_true_matrix, [x, y, 1])
                    _x = int(res[0] + 0.5)
                    _y = int(res[1] + 0.5)
                    if 0 <= _x < input_img.shape[1] and 0 <= _y < input_img.shape[0]:
                        true_result[y, x, :] = input_img[_y, _x, :]

            result = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh)
            self.assertEqual(result.size, pil_img.size)
            # Compute number of different pixels:
            np_result = np.array(result)
            n_diff_pixels = np.sum(np_result != true_result) / 3
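            # != counts per-channel mismatches, so dividing by 3 turns the sum
            # into an (approximate) number of differing RGB pixels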
            # Tolerate fewer than 3 mismatching pixels
            self.assertLess(n_diff_pixels, 3,
                            "a={}, t={}, s={}, sh={}\n".format(a, t, s, sh) +
                            "n diff pixels={}\n".format(np.sum(np.array(result)[:, :, 0] != true_result[:, :, 0])))

        # Test rotation
        a = 45
        _test_transformation(a=a, t=(0, 0), s=1.0, sh=(0.0, 0.0))

        # Test translation
        t = [10, 15]
        _test_transformation(a=0.0, t=t, s=1.0, sh=(0.0, 0.0))

        # Test scale
        s = 1.2
        _test_transformation(a=0.0, t=(0.0, 0.0), s=s, sh=(0.0, 0.0))

        # Test shear
        sh = [45.0, 25.0]
        _test_transformation(a=0.0, t=(0.0, 0.0), s=1.0, sh=sh)

        # Test rotation, scale, translation, shear
        for a in range(-90, 90, 25):
            for t1 in range(-10, 10, 5):
                for s in [0.75, 0.98, 1.0, 1.1, 1.2]:
                    for sh in range(-15, 15, 5):
                        _test_transformation(a=a, t=(t1, t1), s=s, sh=(sh, sh))

    def test_random_rotation(self):

        # Each invalid argument gets its own assertRaises block; with a single
        # block only the first statement would ever run.
        with self.assertRaises(ValueError):
            transforms.RandomRotation(-0.7)
        with self.assertRaises(ValueError):
            transforms.RandomRotation([-0.7])
        with self.assertRaises(ValueError):
            transforms.RandomRotation([-0.7, 0, 0.7])

        t = transforms.RandomRotation(10)
        angle = t.get_params(t.degrees)
        self.assertTrue(-10 < angle < 10)

        t = transforms.RandomRotation((-10, 10))
        angle = t.get_params(t.degrees)
        self.assertTrue(-10 < angle < 10)

        # Checking if RandomRotation can be printed as string
        t.__repr__()

    def test_random_affine(self):

        # NOTE: only the first statement inside this assertRaises block is ever
        # executed (it raises immediately); the remaining constructions document
        # further invalid argument combinations.
        with self.assertRaises(ValueError):
            transforms.RandomAffine(-0.7)
            transforms.RandomAffine([-0.7])
            transforms.RandomAffine([-0.7, 0, 0.7])

            transforms.RandomAffine([-90, 90], translate=2.0)
            transforms.RandomAffine([-90, 90], translate=[-1.0, 1.0])
            transforms.RandomAffine([-90, 90], translate=[-1.0, 0.0, 1.0])

            transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.0])
            transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[-1.0, 1.0])
            transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, -0.5])
            transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 3.0, -0.5])

            transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=-7)
            transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10])
            transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 0, 10])
            transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 0, 10, 0, 10])

        x = np.zeros((100, 100, 3), dtype=np.uint8)
        img = F.to_pil_image(x)

        t = transforms.RandomAffine(10, translate=[0.5, 0.3], scale=[0.7, 1.3], shear=[-10, 10, 20, 40])
        for _ in range(100):
            angle, translations, scale, shear = t.get_params(t.degrees, t.translate, t.scale, t.shear,
                                                             img_size=img.size)
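            # get_params scales the translate fractions by the image size, so
            # the sampled pixel offsets are bounded by img_size * fraction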
            self.assertTrue(-10 < angle < 10)
            self.assertTrue(-img.size[0] * 0.5 <= translations[0] <= img.size[0] * 0.5,
                            "{} vs {}".format(translations[0], img.size[0] * 0.5))
            self.assertTrue(-img.size[1] * 0.5 <= translations[1] <= img.size[1] * 0.5,
                            "{} vs {}".format(translations[1], img.size[1] * 0.5))
            self.assertTrue(0.7 < scale < 1.3)
            self.assertTrue(-10 < shear[0] < 10)
            self.assertTrue(-20 < shear[1] < 40)

        # Checking if RandomAffine can be printed as string
        t.__repr__()

        t = transforms.RandomAffine(10, resample=Image.BILINEAR)
        self.assertIn("Image.BILINEAR", t.__repr__())

    def test_to_grayscale(self):
        """Unit tests for grayscale transform"""

        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')
        x_pil_2 = x_pil.convert('L')
        gray_np = np.array(x_pil_2)

        # Test Set: Grayscale an image with desired number of output channels
        # Case 1: RGB -> 1 channel grayscale
        trans1 = transforms.Grayscale(num_output_channels=1)
        gray_pil_1 = trans1(x_pil)
        gray_np_1 = np.array(gray_pil_1)
        self.assertEqual(gray_pil_1.mode, 'L', 'mode should be L')
        self.assertEqual(gray_np_1.shape, tuple(x_shape[0:2]), 'should be 1 channel')
        np.testing.assert_equal(gray_np, gray_np_1)

        # Case 2: RGB -> 3 channel grayscale
        trans2 = transforms.Grayscale(num_output_channels=3)
        gray_pil_2 = trans2(x_pil)
        gray_np_2 = np.array(gray_pil_2)
        self.assertEqual(gray_pil_2.mode, 'RGB', 'mode should be RGB')
        self.assertEqual(gray_np_2.shape, tuple(x_shape), 'should be 3 channel')
        np.testing.assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
        np.testing.assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
        np.testing.assert_equal(gray_np, gray_np_2[:, :, 0])

        # Case 3: 1 channel grayscale -> 1 channel grayscale
        trans3 = transforms.Grayscale(num_output_channels=1)
        gray_pil_3 = trans3(x_pil_2)
        gray_np_3 = np.array(gray_pil_3)
        self.assertEqual(gray_pil_3.mode, 'L', 'mode should be L')
        self.assertEqual(gray_np_3.shape, tuple(x_shape[0:2]), 'should be 1 channel')
        np.testing.assert_equal(gray_np, gray_np_3)

        # Case 4: 1 channel grayscale -> 3 channel grayscale
        trans4 = transforms.Grayscale(num_output_channels=3)
        gray_pil_4 = trans4(x_pil_2)
        gray_np_4 = np.array(gray_pil_4)
        self.assertEqual(gray_pil_4.mode, 'RGB', 'mode should be RGB')
        self.assertEqual(gray_np_4.shape, tuple(x_shape), 'should be 3 channel')
        np.testing.assert_equal(gray_np_4[:, :, 0], gray_np_4[:, :, 1])
        np.testing.assert_equal(gray_np_4[:, :, 1], gray_np_4[:, :, 2])
        np.testing.assert_equal(gray_np, gray_np_4[:, :, 0])

        # Checking if Grayscale can be printed as string
        trans4.__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_grayscale(self):
        """Unit tests for random grayscale transform"""

        # Test Set 1: RGB -> 3 channel grayscale
        random_state = random.getstate()
        random.seed(42)
        x_shape = [2, 2, 3]
        x_np = np.random.randint(0, 256, x_shape, np.uint8)
        x_pil = Image.fromarray(x_np, mode='RGB')
        x_pil_2 = x_pil.convert('L')
        gray_np = np.array(x_pil_2)

        num_samples = 250
        num_gray = 0
        for _ in range(num_samples):
            gray_pil_2 = transforms.RandomGrayscale(p=0.5)(x_pil)
            gray_np_2 = np.array(gray_pil_2)
            if np.array_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) and \
                    np.array_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) and \
                    np.array_equal(gray_np, gray_np_2[:, :, 0]):
                num_gray += 1

        p_value = stats.binom_test(num_gray, num_samples, p=0.5)
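        # Under the null hypothesis the number of grayscaled samples follows
        # Binomial(num_samples, 0.5); a tiny p-value would indicate that the
        # transform is not applied with probability 0.5.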
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        # Test Set 2: grayscale -> 1 channel grayscale
        random_state = random.getstate()
        random.seed(42)
        x_shape = [2, 2, 3]
        x_np = np.random.randint(0, 256, x_shape, np.uint8)
        x_pil = Image.fromarray(x_np, mode='RGB')
        x_pil_2 = x_pil.convert('L')
        gray_np = np.array(x_pil_2)

        num_samples = 250
        num_gray = 0
        for _ in range(num_samples):
            gray_pil_3 = transforms.RandomGrayscale(p=0.5)(x_pil_2)
            gray_np_3 = np.array(gray_pil_3)
            if np.array_equal(gray_np, gray_np_3):
                num_gray += 1

        p_value = stats.binom_test(num_gray, num_samples, p=1.0)  # Note: grayscale is always unchanged
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        # Test set 3: Explicit tests
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')
        x_pil_2 = x_pil.convert('L')
        gray_np = np.array(x_pil_2)

        # Case 3a: RGB -> 3 channel grayscale (grayscaled)
        trans2 = transforms.RandomGrayscale(p=1.0)
        gray_pil_2 = trans2(x_pil)
        gray_np_2 = np.array(gray_pil_2)
        self.assertEqual(gray_pil_2.mode, 'RGB', 'mode should be RGB')
        self.assertEqual(gray_np_2.shape, tuple(x_shape), 'should be 3 channel')
        np.testing.assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
        np.testing.assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
        np.testing.assert_equal(gray_np, gray_np_2[:, :, 0])

        # Case 3b: RGB -> 3 channel grayscale (unchanged)
        trans2 = transforms.RandomGrayscale(p=0.0)
        gray_pil_2 = trans2(x_pil)
        gray_np_2 = np.array(gray_pil_2)
        self.assertEqual(gray_pil_2.mode, 'RGB', 'mode should be RGB')
        self.assertEqual(gray_np_2.shape, tuple(x_shape), 'should be 3 channel')
        np.testing.assert_equal(x_np, gray_np_2)

        # Case 3c: 1 channel grayscale -> 1 channel grayscale (grayscaled)
        trans3 = transforms.RandomGrayscale(p=1.0)
        gray_pil_3 = trans3(x_pil_2)
        gray_np_3 = np.array(gray_pil_3)
        self.assertEqual(gray_pil_3.mode, 'L', 'mode should be L')
        self.assertEqual(gray_np_3.shape, tuple(x_shape[0:2]), 'should be 1 channel')
        np.testing.assert_equal(gray_np, gray_np_3)

        # Case 3d: 1 channel grayscale -> 1 channel grayscale (unchanged)
        trans3 = transforms.RandomGrayscale(p=0.0)
        gray_pil_3 = trans3(x_pil_2)
        gray_np_3 = np.array(gray_pil_3)
        self.assertEqual(gray_pil_3.mode, 'L', 'mode should be L')
        self.assertEqual(gray_np_3.shape, tuple(x_shape[0:2]), 'should be 1 channel')
        np.testing.assert_equal(gray_np, gray_np_3)

        # Checking if RandomGrayscale can be printed as string
        trans3.__repr__()

    def test_random_erasing(self):
        """Unit tests for random erasing transform"""

        img = torch.rand([3, 60, 60])

        # Test Set 1: Erasing with a scalar (float) value
        img_re = transforms.RandomErasing(value=0.2)
        i, j, h, w, v = img_re.get_params(img, scale=img_re.scale, ratio=img_re.ratio, value=img_re.value)
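        # get_params returns the top-left corner (i, j), the size (h, w) and
        # the fill value v of the rectangle that will be erased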
        img_output = F.erase(img, i, j, h, w, v)
        self.assertEqual(img_output.size(0), 3)

        # Test Set 2: Check if the unerased region is preserved
        orig_unerased = img.clone()
        orig_unerased[:, i:i + h, j:j + w] = 0
        output_unerased = img_output.clone()
        output_unerased[:, i:i + h, j:j + w] = 0
        self.assertTrue(torch.equal(orig_unerased, output_unerased))

        # Test Set 3: Erasing with random value
        img_re = transforms.RandomErasing(value='random')(img)
        self.assertEqual(img_re.size(0), 3)

        # Test Set 4: Erasing with tuple value
        img_re = transforms.RandomErasing(value=(0.2, 0.2, 0.2))(img)
        self.assertEqual(img_re.size(0), 3)

        # Test Set 5: Testing the inplace behaviour
        img_re = transforms.RandomErasing(value=0.2, inplace=True)(img)
        self.assertTrue(torch.equal(img_re, img))

        # Test Set 6: Checking when no erased region is selected
        img = torch.rand([3, 300, 1])
        img_re = transforms.RandomErasing(ratio=(0.1, 0.2), value='random')(img)
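        # the image is only 1 pixel wide, so no rectangle can satisfy the
        # aspect-ratio constraint and the input must come back unchanged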
        self.assertTrue(torch.equal(img_re, img))


if __name__ == '__main__':
    unittest.main()