import os
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from torch._utils_internal import get_file_path_2
from numpy.testing import assert_array_almost_equal
import unittest
import math
import random
import numpy as np
from PIL import Image
try:
    import accimage
except ImportError:
    accimage = None

try:
    from scipy import stats
except ImportError:
    stats = None

GRACE_HOPPER = get_file_path_2(
    os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')


class Tester(unittest.TestCase):

    def test_crop(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        oheight = random.randint(5, (height - 2) // 2) * 2
        owidth = random.randint(5, (width - 2) // 2) * 2
        img = torch.ones(3, height, width)
        oh1 = (height - oheight) // 2
        ow1 = (width - owidth) // 2
        imgnarrow = img[:, oh1:oh1 + oheight, ow1:ow1 + owidth]
        imgnarrow.fill_(0)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.CenterCrop((oheight, owidth)),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.sum(), 0,
                         "height: {} width: {} oheight: {} owidth: {}".format(height, width, oheight, owidth))
        oheight += 1
        owidth += 1
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.CenterCrop((oheight, owidth)),
            transforms.ToTensor(),
        ])(img)
        sum1 = result.sum()
        self.assertGreater(sum1, 1,
                           "height: {} width: {} oheight: {} owidth: {}".format(height, width, oheight, owidth))
        oheight += 1
        owidth += 1
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.CenterCrop((oheight, owidth)),
            transforms.ToTensor(),
        ])(img)
        sum2 = result.sum()
        self.assertGreater(sum2, 0,
                           "height: {} width: {} oheight: {} owidth: {}".format(height, width, oheight, owidth))
        self.assertGreater(sum2, sum1,
                           "height: {} width: {} oheight: {} owidth: {}".format(height, width, oheight, owidth))

    def test_five_crop(self):
        to_pil_image = transforms.ToPILImage()
        h = random.randint(5, 25)
        w = random.randint(5, 25)
        for single_dim in [True, False]:
            crop_h = random.randint(1, h)
            crop_w = random.randint(1, w)
            if single_dim:
                crop_h = min(crop_h, crop_w)
                crop_w = crop_h
                transform = transforms.FiveCrop(crop_h)
            else:
                transform = transforms.FiveCrop((crop_h, crop_w))

            img = torch.FloatTensor(3, h, w).uniform_()
            results = transform(to_pil_image(img))

            self.assertEqual(len(results), 5)
            for crop in results:
                self.assertEqual(crop.size, (crop_w, crop_h))

            to_pil_image = transforms.ToPILImage()
            tl = to_pil_image(img[:, 0:crop_h, 0:crop_w])
            tr = to_pil_image(img[:, 0:crop_h, w - crop_w:])
            bl = to_pil_image(img[:, h - crop_h:, 0:crop_w])
            br = to_pil_image(img[:, h - crop_h:, w - crop_w:])
            center = transforms.CenterCrop((crop_h, crop_w))(to_pil_image(img))
            expected_output = (tl, tr, bl, br, center)
            self.assertEqual(results, expected_output)

    def test_ten_crop(self):
        to_pil_image = transforms.ToPILImage()
        h = random.randint(5, 25)
        w = random.randint(5, 25)
        for should_vflip in [True, False]:
            for single_dim in [True, False]:
                crop_h = random.randint(1, h)
                crop_w = random.randint(1, w)
                if single_dim:
                    crop_h = min(crop_h, crop_w)
                    crop_w = crop_h
                    transform = transforms.TenCrop(crop_h,
                                                   vertical_flip=should_vflip)
                    five_crop = transforms.FiveCrop(crop_h)
                else:
                    transform = transforms.TenCrop((crop_h, crop_w),
                                                   vertical_flip=should_vflip)
                    five_crop = transforms.FiveCrop((crop_h, crop_w))

                img = to_pil_image(torch.FloatTensor(3, h, w).uniform_())
                results = transform(img)
                expected_output = five_crop(img)

                # Checking if FiveCrop and TenCrop can be printed as string
                transform.__repr__()
                five_crop.__repr__()

                if should_vflip:
                    vflipped_img = img.transpose(Image.FLIP_TOP_BOTTOM)
                    expected_output += five_crop(vflipped_img)
                else:
                    hflipped_img = img.transpose(Image.FLIP_LEFT_RIGHT)
                    expected_output += five_crop(hflipped_img)

                self.assertEqual(len(results), 10)
                self.assertEqual(results, expected_output)

    def test_randomresized_params(self):
        height = random.randint(24, 32) * 2
        width = random.randint(24, 32) * 2
        img = torch.ones(3, height, width)
        to_pil_image = transforms.ToPILImage()
        img = to_pil_image(img)
        size = 100
        epsilon = 0.05
        min_scale = 0.25
        for _ in range(10):
            scale_min = max(round(random.random(), 2), min_scale)
            scale_range = (scale_min, scale_min + round(random.random(), 2))
            aspect_min = max(round(random.random(), 2), epsilon)
            aspect_ratio_range = (aspect_min, aspect_min + round(random.random(), 2))
            randresizecrop = transforms.RandomResizedCrop(size, scale_range, aspect_ratio_range)
            i, j, h, w = randresizecrop.get_params(img, scale_range, aspect_ratio_range)
            aspect_ratio_obtained = w / h
            self.assertTrue((min(aspect_ratio_range) - epsilon <= aspect_ratio_obtained and
                             aspect_ratio_obtained <= max(aspect_ratio_range) + epsilon) or
                            aspect_ratio_obtained == 1.0)
            self.assertIsInstance(i, int)
            self.assertIsInstance(j, int)
            self.assertIsInstance(h, int)
            self.assertIsInstance(w, int)

    def test_randomperspective(self):
        for _ in range(10):
            height = random.randint(24, 32) * 2
            width = random.randint(24, 32) * 2
            img = torch.ones(3, height, width)
            to_pil_image = transforms.ToPILImage()
            img = to_pil_image(img)
            perp = transforms.RandomPerspective()
            startpoints, endpoints = perp.get_params(width, height, 0.5)
            tr_img = F.perspective(img, startpoints, endpoints)
            tr_img2 = F.to_tensor(F.perspective(tr_img, endpoints, startpoints))
            tr_img = F.to_tensor(tr_img)
            self.assertEqual(img.size[0], width)
            self.assertEqual(img.size[1], height)
            self.assertGreater(torch.nn.functional.mse_loss(tr_img, F.to_tensor(img)) + 0.3,
                               torch.nn.functional.mse_loss(tr_img2, F.to_tensor(img)))

    def test_randomperspective_fill(self):
        height = 100
        width = 100
        img = torch.ones(3, height, width)
        to_pil_image = transforms.ToPILImage()
        img = to_pil_image(img)

        modes = ("L", "RGB", "F")
        nums_bands = [len(mode) for mode in modes]
        fill = 127

        for mode, num_bands in zip(modes, nums_bands):
            img_conv = img.convert(mode)
            perspective = transforms.RandomPerspective(p=1, fill=fill)
            tr_img = perspective(img_conv)
            pixel = tr_img.getpixel((0, 0))

            if not isinstance(pixel, tuple):
                pixel = (pixel,)
            self.assertTupleEqual(pixel, tuple([fill] * num_bands))

        for mode, num_bands in zip(modes, nums_bands):
            img_conv = img.convert(mode)
            startpoints, endpoints = transforms.RandomPerspective.get_params(width, height, 0.5)
            tr_img = F.perspective(img_conv, startpoints, endpoints, fill=fill)
            pixel = tr_img.getpixel((0, 0))

            if not isinstance(pixel, tuple):
                pixel = (pixel,)
            self.assertTupleEqual(pixel, tuple([fill] * num_bands))

            for wrong_num_bands in set(nums_bands) - {num_bands}:
                with self.assertRaises(ValueError):
                    F.perspective(img_conv, startpoints, endpoints, fill=tuple([fill] * wrong_num_bands))

    def test_resize(self):
        height = random.randint(24, 32) * 2
        width = random.randint(24, 32) * 2
        osize = random.randint(5, 12) * 2

        img = torch.ones(3, height, width)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(osize),
            transforms.ToTensor(),
        ])(img)
        self.assertIn(osize, result.size())
        if height < width:
            self.assertLessEqual(result.size(1), result.size(2))
        elif width < height:
            self.assertGreaterEqual(result.size(1), result.size(2))

        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize([osize, osize]),
            transforms.ToTensor(),
        ])(img)
        self.assertIn(osize, result.size())
        self.assertEqual(result.size(1), osize)
        self.assertEqual(result.size(2), osize)

        oheight = random.randint(5, 12) * 2
        owidth = random.randint(5, 12) * 2
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((oheight, owidth)),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), oheight)
        self.assertEqual(result.size(2), owidth)

        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize([oheight, owidth]),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), oheight)
        self.assertEqual(result.size(2), owidth)
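
    # A small sketch (hypothetical, not part of the original suite) of the
    # scalar-size rule exercised above: given a single int, Resize maps the
    # smaller edge to that size and scales the other edge by the same factor.
    def test_resize_aspect_ratio_sketch(self):
        img = transforms.ToPILImage()(torch.ones(3, 20, 40))  # height 20, width 40
        result = transforms.Resize(10)(img)
        self.assertEqual(result.size, (20, 10))  # PIL size is (width, height)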

    def test_random_crop(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        oheight = random.randint(5, (height - 2) // 2) * 2
        owidth = random.randint(5, (width - 2) // 2) * 2
        img = torch.ones(3, height, width)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop((oheight, owidth)),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), oheight)
        self.assertEqual(result.size(2), owidth)

        padding = random.randint(1, 20)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop((oheight, owidth), padding=padding),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), oheight)
        self.assertEqual(result.size(2), owidth)

        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop((height, width)),
            transforms.ToTensor()
        ])(img)
        self.assertEqual(result.size(1), height)
        self.assertEqual(result.size(2), width)
        self.assertTrue(np.allclose(img.numpy(), result.numpy()))

        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop((height + 1, width + 1), pad_if_needed=True),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), height + 1)
        self.assertEqual(result.size(2), width + 1)

    def test_pad(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        img = torch.ones(3, height, width)
        padding = random.randint(1, 20)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Pad(padding),
            transforms.ToTensor(),
        ])(img)
        self.assertEqual(result.size(1), height + 2 * padding)
        self.assertEqual(result.size(2), width + 2 * padding)
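
    # A minimal sketch (hypothetical) making the padding arithmetic explicit:
    # Pad(p) with a scalar p pads every border by p pixels, so height and
    # width each grow by 2 * p.
    def test_pad_output_size_sketch(self):
        img = transforms.ToPILImage()(torch.ones(3, 8, 10))
        result = transforms.Pad(2)(img)
        self.assertEqual(result.size, (10 + 2 * 2, 8 + 2 * 2))  # (width, height)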

    def test_pad_with_tuple_of_pad_values(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        img = transforms.ToPILImage()(torch.ones(3, height, width))

        padding = tuple([random.randint(1, 20) for _ in range(2)])
        output = transforms.Pad(padding)(img)
        self.assertEqual(output.size, (width + padding[0] * 2, height + padding[1] * 2))

        padding = tuple([random.randint(1, 20) for _ in range(4)])
        output = transforms.Pad(padding)(img)
        self.assertEqual(output.size[0], width + padding[0] + padding[2])
        self.assertEqual(output.size[1], height + padding[1] + padding[3])

        # Checking if Padding can be printed as string
        transforms.Pad(padding).__repr__()

    def test_pad_with_non_constant_padding_modes(self):
        """Unit tests for edge, reflect, symmetric padding"""
        img = torch.zeros(3, 27, 27).byte()
        img[:, :, 0] = 1  # Constant value added to leftmost edge
        img = transforms.ToPILImage()(img)
        img = F.pad(img, 1, (200, 200, 200))

        # Pad 3 to all sides
        edge_padded_img = F.pad(img, 3, padding_mode='edge')
        # First 6 elements of leftmost edge in the middle of the image, values are in order:
        # edge_pad, edge_pad, edge_pad, constant_pad, constant value added to leftmost edge, 0
        edge_middle_slice = np.asarray(edge_padded_img).transpose(2, 0, 1)[0][17][:6]
        self.assertTrue(np.all(edge_middle_slice == np.asarray([200, 200, 200, 200, 1, 0])))
        self.assertEqual(transforms.ToTensor()(edge_padded_img).size(), (3, 35, 35))

        # Pad 3 to left/right, 2 to top/bottom
        reflect_padded_img = F.pad(img, (3, 2), padding_mode='reflect')
        # First 6 elements of leftmost edge in the middle of the image, values are in order:
        # reflect_pad, reflect_pad, reflect_pad, constant_pad, constant value added to leftmost edge, 0
        reflect_middle_slice = np.asarray(reflect_padded_img).transpose(2, 0, 1)[0][17][:6]
        self.assertTrue(np.all(reflect_middle_slice == np.asarray([0, 0, 1, 200, 1, 0])))
        self.assertEqual(transforms.ToTensor()(reflect_padded_img).size(), (3, 33, 35))

        # Pad 3 to left, 2 to top, 2 to right, 1 to bottom
        symmetric_padded_img = F.pad(img, (3, 2, 2, 1), padding_mode='symmetric')
        # First 6 elements of leftmost edge in the middle of the image, values are in order:
        # sym_pad, sym_pad, sym_pad, constant_pad, constant value added to leftmost edge, 0
        symmetric_middle_slice = np.asarray(symmetric_padded_img).transpose(2, 0, 1)[0][17][:6]
        self.assertTrue(np.all(symmetric_middle_slice == np.asarray([0, 1, 200, 200, 1, 0])))
        self.assertEqual(transforms.ToTensor()(symmetric_padded_img).size(), (3, 32, 34))
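
    # A short sketch (hypothetical) of how the three non-constant modes differ,
    # shown with numpy's equivalent modes on a 1-D row; the PIL-based padding
    # above follows the same conventions.
    def test_padding_mode_semantics_sketch(self):
        row = np.array([1, 2, 3])
        # edge repeats the border value
        self.assertTrue(np.array_equal(np.pad(row, 2, mode='edge'),
                                       [1, 1, 1, 2, 3, 3, 3]))
        # reflect mirrors around the border without repeating it
        self.assertTrue(np.array_equal(np.pad(row, 2, mode='reflect'),
                                       [3, 2, 1, 2, 3, 2, 1]))
        # symmetric mirrors including the border value
        self.assertTrue(np.array_equal(np.pad(row, 2, mode='symmetric'),
                                       [2, 1, 1, 2, 3, 3, 2]))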

    def test_pad_raises_with_invalid_pad_sequence_len(self):
        with self.assertRaises(ValueError):
            transforms.Pad(())

        with self.assertRaises(ValueError):
            transforms.Pad((1, 2, 3))

        with self.assertRaises(ValueError):
            transforms.Pad((1, 2, 3, 4, 5))

    def test_lambda(self):
        trans = transforms.Lambda(lambda x: x.add(10))
        x = torch.randn(10)
        y = trans(x)
        self.assertTrue(y.equal(torch.add(x, 10)))

        trans = transforms.Lambda(lambda x: x.add_(10))
        x = torch.randn(10)
        y = trans(x)
        self.assertTrue(y.equal(x))

        # Checking if Lambda can be printed as string
        trans.__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_apply(self):
        random_state = random.getstate()
        random.seed(42)
        random_apply_transform = transforms.RandomApply(
            [
                transforms.RandomRotation((-45, 45)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
            ], p=0.75
        )
        img = transforms.ToPILImage()(torch.rand(3, 10, 10))
        num_samples = 250
        num_applies = 0
        for _ in range(num_samples):
            out = random_apply_transform(img)
            if out != img:
                num_applies += 1

        p_value = stats.binom_test(num_applies, num_samples, p=0.75)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        # Checking if RandomApply can be printed as string
        random_apply_transform.__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_choice(self):
        random_state = random.getstate()
        random.seed(42)
        random_choice_transform = transforms.RandomChoice(
            [
                transforms.Resize(15),
                transforms.Resize(20),
                transforms.CenterCrop(10)
            ]
        )
        img = transforms.ToPILImage()(torch.rand(3, 25, 25))
        num_samples = 250
        num_resize_15 = 0
        num_resize_20 = 0
        num_crop_10 = 0
        for _ in range(num_samples):
            out = random_choice_transform(img)
            if out.size == (15, 15):
                num_resize_15 += 1
            elif out.size == (20, 20):
                num_resize_20 += 1
            elif out.size == (10, 10):
                num_crop_10 += 1

        p_value = stats.binom_test(num_resize_15, num_samples, p=0.33333)
        self.assertGreater(p_value, 0.0001)
        p_value = stats.binom_test(num_resize_20, num_samples, p=0.33333)
        self.assertGreater(p_value, 0.0001)
        p_value = stats.binom_test(num_crop_10, num_samples, p=0.33333)
        self.assertGreater(p_value, 0.0001)

        random.setstate(random_state)
        # Checking if RandomChoice can be printed as string
        random_choice_transform.__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_order(self):
        random_state = random.getstate()
        random.seed(42)
        random_order_transform = transforms.RandomOrder(
            [
                transforms.Resize(20),
                transforms.CenterCrop(10)
            ]
        )
        img = transforms.ToPILImage()(torch.rand(3, 25, 25))
        num_samples = 250
        num_normal_order = 0
        resize_crop_out = transforms.CenterCrop(10)(transforms.Resize(20)(img))
        for _ in range(num_samples):
            out = random_order_transform(img)
            if out == resize_crop_out:
                num_normal_order += 1

        p_value = stats.binom_test(num_normal_order, num_samples, p=0.5)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        # Checking if RandomOrder can be printed as string
        random_order_transform.__repr__()

    def test_to_tensor(self):
        test_channels = [1, 3, 4]
        height, width = 4, 4
        trans = transforms.ToTensor()

        with self.assertRaises(TypeError):
            trans(np.random.rand(1, height, width).tolist())

        with self.assertRaises(ValueError):
            trans(np.random.rand(height))
        with self.assertRaises(ValueError):
            trans(np.random.rand(1, 1, height, width))

        for channels in test_channels:
            input_data = torch.ByteTensor(channels, height, width).random_(0, 255).float().div_(255)
            img = transforms.ToPILImage()(input_data)
            output = trans(img)
            self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

            ndarray = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
            output = trans(ndarray)
            expected_output = ndarray.transpose((2, 0, 1)) / 255.0
            self.assertTrue(np.allclose(output.numpy(), expected_output))

            ndarray = np.random.rand(height, width, channels).astype(np.float32)
            output = trans(ndarray)
            expected_output = ndarray.transpose((2, 0, 1))
            self.assertTrue(np.allclose(output.numpy(), expected_output))

        # separate test for mode '1' PIL images
        input_data = torch.ByteTensor(1, height, width).bernoulli_()
        img = transforms.ToPILImage()(input_data.mul(255)).convert('1')
        output = trans(img)
        self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

    @unittest.skipIf(accimage is None, 'accimage not available')
    def test_accimage_to_tensor(self):
        trans = transforms.ToTensor()

        expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
        output = trans(accimage.Image(GRACE_HOPPER))

        self.assertEqual(expected_output.size(), output.size())
        self.assertTrue(np.allclose(output.numpy(), expected_output.numpy()))

    @unittest.skipIf(accimage is None, 'accimage not available')
    def test_accimage_resize(self):
        trans = transforms.Compose([
            transforms.Resize(256, interpolation=Image.LINEAR),
            transforms.ToTensor(),
        ])

        # Checking if Compose, Resize and ToTensor can be printed as string
        trans.__repr__()

        expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
        output = trans(accimage.Image(GRACE_HOPPER))

        self.assertEqual(expected_output.size(), output.size())
        self.assertLess(np.abs((expected_output - output).mean()), 1e-3)
        self.assertLess((expected_output - output).var(), 1e-5)
        # note the high absolute tolerance
        self.assertTrue(np.allclose(output.numpy(), expected_output.numpy(), atol=5e-2))

    @unittest.skipIf(accimage is None, 'accimage not available')
    def test_accimage_crop(self):
        trans = transforms.Compose([
            transforms.CenterCrop(256),
            transforms.ToTensor(),
        ])

        # Checking if Compose, CenterCrop and ToTensor can be printed as string
        trans.__repr__()

        expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
        output = trans(accimage.Image(GRACE_HOPPER))

        self.assertEqual(expected_output.size(), output.size())
        self.assertTrue(np.allclose(output.numpy(), expected_output.numpy()))

    def test_1_channel_tensor_to_pil_image(self):
        to_tensor = transforms.ToTensor()

        img_data_float = torch.Tensor(1, 4, 4).uniform_()
        img_data_byte = torch.ByteTensor(1, 4, 4).random_(0, 255)
        img_data_short = torch.ShortTensor(1, 4, 4).random_()
        img_data_int = torch.IntTensor(1, 4, 4).random_()

        inputs = [img_data_float, img_data_byte, img_data_short, img_data_int]
        expected_outputs = [img_data_float.mul(255).int().float().div(255).numpy(),
                            img_data_byte.float().div(255.0).numpy(),
                            img_data_short.numpy(),
                            img_data_int.numpy()]
        expected_modes = ['L', 'L', 'I;16', 'I']

        for img_data, expected_output, mode in zip(inputs, expected_outputs, expected_modes):
            for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]:
                img = transform(img_data)
                self.assertEqual(img.mode, mode)
                self.assertTrue(np.allclose(expected_output, to_tensor(img).numpy()))
        # 'F' mode for torch.FloatTensor
        img_F_mode = transforms.ToPILImage(mode='F')(img_data_float)
        self.assertEqual(img_F_mode.mode, 'F')
        self.assertTrue(np.allclose(np.array(Image.fromarray(img_data_float.squeeze(0).numpy(), mode='F')),
                                    np.array(img_F_mode)))

    def test_1_channel_ndarray_to_pil_image(self):
        img_data_float = torch.Tensor(4, 4, 1).uniform_().numpy()
        img_data_byte = torch.ByteTensor(4, 4, 1).random_(0, 255).numpy()
        img_data_short = torch.ShortTensor(4, 4, 1).random_().numpy()
        img_data_int = torch.IntTensor(4, 4, 1).random_().numpy()

        inputs = [img_data_float, img_data_byte, img_data_short, img_data_int]
        expected_modes = ['F', 'L', 'I;16', 'I']
        for img_data, mode in zip(inputs, expected_modes):
            for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]:
                img = transform(img_data)
                self.assertEqual(img.mode, mode)
                self.assertTrue(np.allclose(img_data[:, :, 0], img))

    def test_2_channel_ndarray_to_pil_image(self):
        def verify_img_data(img_data, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'LA')  # default should assume LA
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)
            split = img.split()
            for i in range(2):
                self.assertTrue(np.allclose(img_data[:, :, i], split[i]))

        img_data = torch.ByteTensor(4, 4, 2).random_(0, 255).numpy()
        for mode in [None, 'LA']:
            verify_img_data(img_data, mode)

        transforms.ToPILImage().__repr__()

        with self.assertRaises(ValueError):
            # should raise if we try a mode for 4 or 1 or 3 channel images
            transforms.ToPILImage(mode='RGBA')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='P')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='RGB')(img_data)

    def test_2_channel_tensor_to_pil_image(self):
        def verify_img_data(img_data, expected_output, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'LA')  # default should assume LA
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)
            split = img.split()
            for i in range(2):
                self.assertTrue(np.allclose(expected_output[i].numpy(), F.to_tensor(split[i]).numpy()))

        img_data = torch.Tensor(2, 4, 4).uniform_()
        expected_output = img_data.mul(255).int().float().div(255)
        for mode in [None, 'LA']:
            verify_img_data(img_data, expected_output, mode=mode)

        with self.assertRaises(ValueError):
            # should raise if we try a mode for 4 or 1 or 3 channel images
            transforms.ToPILImage(mode='RGBA')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='P')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='RGB')(img_data)

    def test_3_channel_tensor_to_pil_image(self):
        def verify_img_data(img_data, expected_output, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'RGB')  # default should assume RGB
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)
            split = img.split()
            for i in range(3):
                self.assertTrue(np.allclose(expected_output[i].numpy(), F.to_tensor(split[i]).numpy()))

        img_data = torch.Tensor(3, 4, 4).uniform_()
        expected_output = img_data.mul(255).int().float().div(255)
        for mode in [None, 'RGB', 'HSV', 'YCbCr']:
            verify_img_data(img_data, expected_output, mode=mode)

        with self.assertRaises(ValueError):
            # should raise if we try a mode for 4 or 1 or 2 channel images
            transforms.ToPILImage(mode='RGBA')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='P')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='LA')(img_data)

        with self.assertRaises(ValueError):
            transforms.ToPILImage()(torch.Tensor(1, 3, 4, 4).uniform_())

    def test_3_channel_ndarray_to_pil_image(self):
        def verify_img_data(img_data, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'RGB')  # default should assume RGB
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)
            split = img.split()
            for i in range(3):
                self.assertTrue(np.allclose(img_data[:, :, i], split[i]))

        img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()
        for mode in [None, 'RGB', 'HSV', 'YCbCr']:
            verify_img_data(img_data, mode)

        # Checking if ToPILImage can be printed as string
        transforms.ToPILImage().__repr__()

        with self.assertRaises(ValueError):
            # should raise if we try a mode for 4 or 1 or 2 channel images
            transforms.ToPILImage(mode='RGBA')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='P')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='LA')(img_data)

    def test_4_channel_tensor_to_pil_image(self):
        def verify_img_data(img_data, expected_output, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'RGBA')  # default should assume RGBA
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)

            split = img.split()
            for i in range(4):
                self.assertTrue(np.allclose(expected_output[i].numpy(), F.to_tensor(split[i]).numpy()))

        img_data = torch.Tensor(4, 4, 4).uniform_()
        expected_output = img_data.mul(255).int().float().div(255)
        for mode in [None, 'RGBA', 'CMYK', 'RGBX']:
            verify_img_data(img_data, expected_output, mode)

        with self.assertRaises(ValueError):
            # should raise if we try a mode for 3 or 1 or 2 channel images
            transforms.ToPILImage(mode='RGB')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='P')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='LA')(img_data)

    def test_4_channel_ndarray_to_pil_image(self):
        def verify_img_data(img_data, mode):
            if mode is None:
                img = transforms.ToPILImage()(img_data)
                self.assertEqual(img.mode, 'RGBA')  # default should assume RGBA
            else:
                img = transforms.ToPILImage(mode=mode)(img_data)
                self.assertEqual(img.mode, mode)
            split = img.split()
            for i in range(4):
                self.assertTrue(np.allclose(img_data[:, :, i], split[i]))

        img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy()
        for mode in [None, 'RGBA', 'CMYK', 'RGBX']:
            verify_img_data(img_data, mode)

        with self.assertRaises(ValueError):
            # should raise if we try a mode for 3 or 1 or 2 channel images
            transforms.ToPILImage(mode='RGB')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='P')(img_data)
        with self.assertRaises(ValueError):
            transforms.ToPILImage(mode='LA')(img_data)

    def test_2d_tensor_to_pil_image(self):
        to_tensor = transforms.ToTensor()

        img_data_float = torch.Tensor(4, 4).uniform_()
        img_data_byte = torch.ByteTensor(4, 4).random_(0, 255)
        img_data_short = torch.ShortTensor(4, 4).random_()
        img_data_int = torch.IntTensor(4, 4).random_()

        inputs = [img_data_float, img_data_byte, img_data_short, img_data_int]
        expected_outputs = [img_data_float.mul(255).int().float().div(255).numpy(),
                            img_data_byte.float().div(255.0).numpy(),
                            img_data_short.numpy(),
                            img_data_int.numpy()]
        expected_modes = ['L', 'L', 'I;16', 'I']

        for img_data, expected_output, mode in zip(inputs, expected_outputs, expected_modes):
            for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]:
                img = transform(img_data)
                self.assertEqual(img.mode, mode)
                self.assertTrue(np.allclose(expected_output, to_tensor(img).numpy()))

    def test_2d_ndarray_to_pil_image(self):
        img_data_float = torch.Tensor(4, 4).uniform_().numpy()
        img_data_byte = torch.ByteTensor(4, 4).random_(0, 255).numpy()
        img_data_short = torch.ShortTensor(4, 4).random_().numpy()
        img_data_int = torch.IntTensor(4, 4).random_().numpy()

        inputs = [img_data_float, img_data_byte, img_data_short, img_data_int]
        expected_modes = ['F', 'L', 'I;16', 'I']
        for img_data, mode in zip(inputs, expected_modes):
            for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]:
                img = transform(img_data)
                self.assertEqual(img.mode, mode)
                self.assertTrue(np.allclose(img_data, img))

    def test_tensor_bad_types_to_pil_image(self):
        with self.assertRaises(ValueError):
            transforms.ToPILImage()(torch.ones(1, 3, 4, 4))

    def test_ndarray_bad_types_to_pil_image(self):
        trans = transforms.ToPILImage()
        with self.assertRaises(TypeError):
            trans(np.ones([4, 4, 1], np.int64))
        with self.assertRaises(TypeError):
            trans(np.ones([4, 4, 1], np.uint16))
        with self.assertRaises(TypeError):
            trans(np.ones([4, 4, 1], np.uint32))
        with self.assertRaises(TypeError):
            trans(np.ones([4, 4, 1], np.float64))

        with self.assertRaises(ValueError):
            transforms.ToPILImage()(np.ones([1, 4, 4, 3]))

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_vertical_flip(self):
        random_state = random.getstate()
        random.seed(42)
        img = transforms.ToPILImage()(torch.rand(3, 10, 10))
        vimg = img.transpose(Image.FLIP_TOP_BOTTOM)

        num_samples = 250
        num_vertical = 0
        for _ in range(num_samples):
            out = transforms.RandomVerticalFlip()(img)
            if out == vimg:
                num_vertical += 1

        p_value = stats.binom_test(num_vertical, num_samples, p=0.5)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        num_samples = 250
        num_vertical = 0
        for _ in range(num_samples):
            out = transforms.RandomVerticalFlip(p=0.7)(img)
            if out == vimg:
                num_vertical += 1

        p_value = stats.binom_test(num_vertical, num_samples, p=0.7)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        # Checking if RandomVerticalFlip can be printed as string
        transforms.RandomVerticalFlip().__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_horizontal_flip(self):
        random_state = random.getstate()
        random.seed(42)
        img = transforms.ToPILImage()(torch.rand(3, 10, 10))
        himg = img.transpose(Image.FLIP_LEFT_RIGHT)

        num_samples = 250
        num_horizontal = 0
        for _ in range(num_samples):
            out = transforms.RandomHorizontalFlip()(img)
            if out == himg:
                num_horizontal += 1

        p_value = stats.binom_test(num_horizontal, num_samples, p=0.5)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        num_samples = 250
        num_horizontal = 0
        for _ in range(num_samples):
            out = transforms.RandomHorizontalFlip(p=0.7)(img)
            if out == himg:
                num_horizontal += 1

        p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        # Checking if RandomHorizontalFlip can be printed as string
        transforms.RandomHorizontalFlip().__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats is not available')
    def test_normalize(self):
        def samples_from_standard_normal(tensor):
            p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue
            return p_value > 0.0001

        random_state = random.getstate()
        random.seed(42)
        for channels in [1, 3]:
            img = torch.rand(channels, 10, 10)
            mean = [img[c].mean() for c in range(channels)]
            std = [img[c].std() for c in range(channels)]
            normalized = transforms.Normalize(mean, std)(img)
            self.assertTrue(samples_from_standard_normal(normalized))
        random.setstate(random_state)

        # Checking if Normalize can be printed as string
        transforms.Normalize(mean, std).__repr__()

        # Checking the optional in-place behaviour
        tensor = torch.rand((1, 16, 16))
        tensor_inplace = transforms.Normalize((0.5,), (0.5,), inplace=True)(tensor)
        self.assertTrue(torch.equal(tensor, tensor_inplace))
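
    # A minimal sketch (hypothetical, not part of the original suite) of the
    # arithmetic Normalize performs: each channel c is mapped to
    # (x[c] - mean[c]) / std[c].
    def test_normalize_arithmetic_sketch(self):
        img = torch.rand(3, 4, 4)
        normalized = transforms.Normalize([0.5] * 3, [0.25] * 3)(img.clone())
        self.assertTrue(torch.allclose(normalized, (img - 0.5) / 0.25))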

    def test_normalize_different_dtype(self):
        for dtype1 in [torch.float32, torch.float64]:
            img = torch.rand(3, 10, 10, dtype=dtype1)
            for dtype2 in [torch.int64, torch.float32, torch.float64]:
                mean = torch.tensor([1, 2, 3], dtype=dtype2)
                std = torch.tensor([1, 2, 1], dtype=dtype2)
                # checks that it doesn't crash
                transforms.functional.normalize(img, mean, std)

    def test_normalize_3d_tensor(self):
        torch.manual_seed(28)
        n_channels = 3
        img_size = 10
        mean = torch.rand(n_channels)
        std = torch.rand(n_channels)
        img = torch.rand(n_channels, img_size, img_size)
        target = F.normalize(img, mean, std).numpy()

        mean_unsqueezed = mean.view(-1, 1, 1)
        std_unsqueezed = std.view(-1, 1, 1)
        result1 = F.normalize(img, mean_unsqueezed, std_unsqueezed)
        result2 = F.normalize(img,
                              mean_unsqueezed.repeat(1, img_size, img_size),
                              std_unsqueezed.repeat(1, img_size, img_size))
        assert_array_almost_equal(target, result1.numpy())
        assert_array_almost_equal(target, result2.numpy())

    def test_adjust_brightness(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')

        # test 0
        y_pil = F.adjust_brightness(x_pil, 1)
        y_np = np.array(y_pil)
        self.assertTrue(np.allclose(y_np, x_np))

        # test 1
        y_pil = F.adjust_brightness(x_pil, 0.5)
        y_np = np.array(y_pil)
        y_ans = [0, 2, 6, 27, 67, 113, 18, 4, 117, 45, 127, 0]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 2
        y_pil = F.adjust_brightness(x_pil, 2)
        y_np = np.array(y_pil)
        y_ans = [0, 10, 26, 108, 255, 255, 74, 16, 255, 180, 255, 2]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))
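
    # A rough sketch (hypothetical) of the brightness math checked above:
    # adjust_brightness scales every channel by the factor and clips to
    # [0, 255]; PIL's exact rounding may differ by at most one level.
    def test_adjust_brightness_formula_sketch(self):
        x = np.array([0, 100, 200], dtype=np.uint8).reshape(1, 1, 3)
        y = F.adjust_brightness(Image.fromarray(x, mode='RGB'), 2)
        expected = np.clip(x.astype(np.int64) * 2, 0, 255)
        self.assertTrue(np.allclose(np.array(y), expected, atol=1))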

    def test_adjust_contrast(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')

        # test 0
        y_pil = F.adjust_contrast(x_pil, 1)
        y_np = np.array(y_pil)
        self.assertTrue(np.allclose(y_np, x_np))

        # test 1
        y_pil = F.adjust_contrast(x_pil, 0.5)
        y_np = np.array(y_pil)
        y_ans = [43, 45, 49, 70, 110, 156, 61, 47, 160, 88, 170, 43]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 2
        y_pil = F.adjust_contrast(x_pil, 2)
        y_np = np.array(y_pil)
        y_ans = [0, 0, 0, 22, 184, 255, 0, 0, 255, 94, 255, 0]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))
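
    # A rough sketch (hypothetical) of the contrast math checked above:
    # adjust_contrast blends the image with its grayscale mean, i.e.
    # out is approximately mean + factor * (x - mean), up to integer rounding.
    def test_adjust_contrast_formula_sketch(self):
        x = np.array([0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1],
                     dtype=np.uint8).reshape(2, 2, 3)
        img = Image.fromarray(x, mode='RGB')
        mean = int(np.array(img.convert('L')).mean() + 0.5)
        y = np.array(F.adjust_contrast(img, 0.5))
        expected = np.clip(mean + 0.5 * (x.astype(np.float64) - mean), 0, 255)
        self.assertTrue(np.allclose(y, expected, atol=2))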

    @unittest.skipIf(Image.__version__ >= '7', "Temporarily disabled")
    def test_adjust_saturation(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')

        # test 0
        y_pil = F.adjust_saturation(x_pil, 1)
        y_np = np.array(y_pil)
        self.assertTrue(np.allclose(y_np, x_np))

        # test 1
        y_pil = F.adjust_saturation(x_pil, 0.5)
        y_np = np.array(y_pil)
        y_ans = [2, 4, 8, 87, 128, 173, 39, 25, 138, 133, 215, 88]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 2
        y_pil = F.adjust_saturation(x_pil, 2)
        y_np = np.array(y_pil)
        y_ans = [0, 6, 22, 0, 149, 255, 32, 0, 255, 4, 255, 0]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

    def test_adjust_hue(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')

        with self.assertRaises(ValueError):
            F.adjust_hue(x_pil, -0.7)
        with self.assertRaises(ValueError):
            F.adjust_hue(x_pil, 1)

        # test 0: almost same as x_data but not exact.
        # probably because hsv <-> rgb floating point ops
        y_pil = F.adjust_hue(x_pil, 0)
        y_np = np.array(y_pil)
        y_ans = [0, 5, 13, 54, 139, 226, 35, 8, 234, 91, 255, 1]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 1
        y_pil = F.adjust_hue(x_pil, 0.25)
        y_np = np.array(y_pil)
        y_ans = [13, 0, 12, 224, 54, 226, 234, 8, 99, 1, 222, 255]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 2
        y_pil = F.adjust_hue(x_pil, -0.25)
        y_np = np.array(y_pil)
        y_ans = [0, 13, 2, 54, 226, 58, 8, 234, 152, 255, 43, 1]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

    def test_adjust_gamma(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')

        # test 0
        y_pil = F.adjust_gamma(x_pil, 1)
        y_np = np.array(y_pil)
        self.assertTrue(np.allclose(y_np, x_np))

        # test 1
        y_pil = F.adjust_gamma(x_pil, 0.5)
        y_np = np.array(y_pil)
        y_ans = [0, 35, 57, 117, 185, 240, 97, 45, 244, 151, 255, 15]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))

        # test 2
        y_pil = F.adjust_gamma(x_pil, 2)
        y_np = np.array(y_pil)
        y_ans = [0, 0, 0, 11, 71, 200, 5, 0, 214, 31, 255, 0]
        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
        self.assertTrue(np.allclose(y_np, y_ans))
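
    # A minimal sketch (hypothetical) of the gamma curve checked above:
    # out = 255 * gain * (in / 255) ** gamma, applied per channel.
    def test_adjust_gamma_formula_sketch(self):
        x = np.array([0, 64, 255], dtype=np.uint8).reshape(1, 1, 3)
        y = F.adjust_gamma(Image.fromarray(x, mode='RGB'), 2)
        expected = (255 * (x / 255.0) ** 2).astype(np.uint8)
        self.assertTrue(np.allclose(np.array(y), expected, atol=1))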

    def test_adjusts_L_mode(self):
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_rgb = Image.fromarray(x_np, mode='RGB')

        x_l = x_rgb.convert('L')
        self.assertEqual(F.adjust_brightness(x_l, 2).mode, 'L')
        self.assertEqual(F.adjust_saturation(x_l, 2).mode, 'L')
        self.assertEqual(F.adjust_contrast(x_l, 2).mode, 'L')
        self.assertEqual(F.adjust_hue(x_l, 0.4).mode, 'L')
        self.assertEqual(F.adjust_gamma(x_l, 0.5).mode, 'L')

    def test_color_jitter(self):
        color_jitter = transforms.ColorJitter(2, 2, 2, 0.1)

        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')
        x_pil_2 = x_pil.convert('L')

        for i in range(10):
            y_pil = color_jitter(x_pil)
            self.assertEqual(y_pil.mode, x_pil.mode)

            y_pil_2 = color_jitter(x_pil_2)
            self.assertEqual(y_pil_2.mode, x_pil_2.mode)

        # Checking if ColorJitter can be printed as string
        color_jitter.__repr__()

    def test_linear_transformation(self):
        num_samples = 1000
        x = torch.randn(num_samples, 3, 10, 10)
        flat_x = x.view(x.size(0), x.size(1) * x.size(2) * x.size(3))
        # compute principal components
        sigma = torch.mm(flat_x.t(), flat_x) / flat_x.size(0)
        u, s, _ = np.linalg.svd(sigma.numpy())
        zca_epsilon = 1e-10  # avoid division by 0
        d = torch.Tensor(np.diag(1. / np.sqrt(s + zca_epsilon)))
        u = torch.Tensor(u)
        principal_components = torch.mm(torch.mm(u, d), u.t())
        mean_vector = (torch.sum(flat_x, dim=0) / flat_x.size(0))
        # initialize whitening matrix
        whitening = transforms.LinearTransformation(principal_components, mean_vector)
        # estimate covariance and mean using weak law of large number
        num_features = flat_x.size(1)
        cov = 0.0
        mean = 0.0
        for i in x:
            xwhite = whitening(i)
            xwhite = xwhite.view(1, -1).numpy()
            cov += np.dot(xwhite, xwhite.T) / num_features
            mean += np.sum(xwhite) / num_features
        # if rtol for std = 1e-3 then rtol for cov = 2e-3 as std**2 = cov
        self.assertTrue(np.allclose(cov / num_samples, np.identity(1), rtol=2e-3),
                        "cov not close to 1")
        self.assertTrue(np.allclose(mean / num_samples, 0, rtol=1e-3),
                        "mean not close to 0")

        # Checking if LinearTransformation can be printed as string
        whitening.__repr__()
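
    # A compact sketch (hypothetical) of the whitening construction above:
    # with covariance Sigma = U diag(s) U^T, the ZCA matrix is
    # W = U diag(1 / sqrt(s + eps)) U^T, so W Sigma W^T is close to identity.
    def test_zca_matrix_sketch(self):
        x = np.random.RandomState(0).randn(500, 8)
        sigma = x.T.dot(x) / x.shape[0]
        u, s, _ = np.linalg.svd(sigma)
        w = u.dot(np.diag(1.0 / np.sqrt(s + 1e-10))).dot(u.T)
        self.assertTrue(np.allclose(w.dot(sigma).dot(w.T), np.eye(8), atol=1e-5))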

    def test_rotate(self):
        x = np.zeros((100, 100, 3), dtype=np.uint8)
        x[40, 40] = [255, 255, 255]

        with self.assertRaises(TypeError):
            F.rotate(x, 10)

        img = F.to_pil_image(x)

        result = F.rotate(img, 45)
        self.assertEqual(result.size, (100, 100))
        r, c, ch = np.where(result)
        self.assertTrue(all(x in r for x in [49, 50]))
        self.assertTrue(all(x in c for x in [36]))
        self.assertTrue(all(x in ch for x in [0, 1, 2]))

        result = F.rotate(img, 45, expand=True)
        self.assertEqual(result.size, (142, 142))
        r, c, ch = np.where(result)
        self.assertTrue(all(x in r for x in [70, 71]))
        self.assertTrue(all(x in c for x in [57]))
        self.assertTrue(all(x in ch for x in [0, 1, 2]))

        result = F.rotate(img, 45, center=(40, 40))
        self.assertEqual(result.size, (100, 100))
        r, c, ch = np.where(result)
        self.assertTrue(all(x in r for x in [40]))
        self.assertTrue(all(x in c for x in [40]))
        self.assertTrue(all(x in ch for x in [0, 1, 2]))

        result_a = F.rotate(img, 90)
        result_b = F.rotate(img, -270)

        self.assertTrue(np.all(np.array(result_a) == np.array(result_b)))

    def test_rotate_fill(self):
        img = F.to_pil_image(np.ones((100, 100, 3), dtype=np.uint8) * 255, "RGB")

        modes = ("L", "RGB", "F")
        nums_bands = [len(mode) for mode in modes]
        fill = 127

        for mode, num_bands in zip(modes, nums_bands):
            img_conv = img.convert(mode)
            img_rot = F.rotate(img_conv, 45.0, fill=fill)
            pixel = img_rot.getpixel((0, 0))

            if not isinstance(pixel, tuple):
                pixel = (pixel,)
            self.assertTupleEqual(pixel, tuple([fill] * num_bands))

            for wrong_num_bands in set(nums_bands) - {num_bands}:
                with self.assertRaises(ValueError):
                    F.rotate(img_conv, 45.0, fill=tuple([fill] * wrong_num_bands))

    def test_affine(self):
        input_img = np.zeros((40, 40, 3), dtype=np.uint8)
        pts = []
        cnt = [20, 20]
        for pt in [(16, 16), (20, 16), (20, 20)]:
            for i in range(-5, 5):
                for j in range(-5, 5):
                    input_img[pt[0] + i, pt[1] + j, :] = [255, 155, 55]
                    pts.append((pt[0] + i, pt[1] + j))
        pts = list(set(pts))

        with self.assertRaises(TypeError):
            F.affine(input_img, 10)

        pil_img = F.to_pil_image(input_img)

        def _to_3x3_inv(inv_result_matrix):
            result_matrix = np.zeros((3, 3))
            result_matrix[:2, :] = np.array(inv_result_matrix).reshape((2, 3))
            result_matrix[2, 2] = 1
            return np.linalg.inv(result_matrix)

        def _test_transformation(a, t, s, sh):
            a_rad = math.radians(a)
            s_rad = [math.radians(sh_) for sh_ in sh]
            cx, cy = cnt
            tx, ty = t
            sx, sy = s_rad
            rot = a_rad

            # 1) Check transformation matrix:
            C = np.array([[1, 0, cx],
                          [0, 1, cy],
                          [0, 0, 1]])
            T = np.array([[1, 0, tx],
                          [0, 1, ty],
                          [0, 0, 1]])
            Cinv = np.linalg.inv(C)

            RS = np.array(
                [[s * math.cos(rot), -s * math.sin(rot), 0],
                 [s * math.sin(rot), s * math.cos(rot), 0],
                 [0, 0, 1]])

            SHx = np.array([[1, -math.tan(sx), 0],
                            [0, 1, 0],
                            [0, 0, 1]])

            SHy = np.array([[1, 0, 0],
                            [-math.tan(sy), 1, 0],
                            [0, 0, 1]])

            RSS = np.matmul(RS, np.matmul(SHy, SHx))

            true_matrix = np.matmul(T, np.matmul(C, np.matmul(RSS, Cinv)))

            result_matrix = _to_3x3_inv(F._get_inverse_affine_matrix(center=cnt, angle=a,
                                                                     translate=t, scale=s, shear=sh))
            self.assertLess(np.sum(np.abs(true_matrix - result_matrix)), 1e-10)
            # 2) Perform inverse mapping:
            true_result = np.zeros((40, 40, 3), dtype=np.uint8)
            inv_true_matrix = np.linalg.inv(true_matrix)
            for y in range(true_result.shape[0]):
                for x in range(true_result.shape[1]):
                    res = np.dot(inv_true_matrix, [x, y, 1])
                    _x = int(res[0] + 0.5)
                    _y = int(res[1] + 0.5)
                    if 0 <= _x < input_img.shape[1] and 0 <= _y < input_img.shape[0]:
                        true_result[y, x, :] = input_img[_y, _x, :]

            result = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh)
            self.assertEqual(result.size, pil_img.size)

            # Compute the number of differing pixels (averaged over the 3 channels):
            np_result = np.array(result)
            n_diff_pixels = np.sum(np_result != true_result) / 3
            # Tolerate up to two mismatched pixels from rounding at the square borders
            self.assertLess(n_diff_pixels, 3,
                            "a={}, t={}, s={}, sh={}\n".format(a, t, s, sh) +
                            "n diff pixels={}\n".format(n_diff_pixels))

        # Test rotation
        a = 45
        _test_transformation(a=a, t=(0, 0), s=1.0, sh=(0.0, 0.0))

        # Test translation
        t = [10, 15]
        _test_transformation(a=0.0, t=t, s=1.0, sh=(0.0, 0.0))

        # Test scale
        s = 1.2
        _test_transformation(a=0.0, t=(0.0, 0.0), s=s, sh=(0.0, 0.0))

        # Test shear
        sh = [45.0, 25.0]
        _test_transformation(a=0.0, t=(0.0, 0.0), s=1.0, sh=sh)

        # Test rotation, scale, translation, shear
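        # 8 angles x 4 translations x 5 scales x 6 shears = 960 combinations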
        for a in range(-90, 90, 25):
            for t1 in range(-10, 10, 5):
                for s in [0.75, 0.98, 1.0, 1.1, 1.2]:
                    for sh in range(-15, 15, 5):
                        _test_transformation(a=a, t=(t1, t1), s=s, sh=(sh, sh))

    def test_random_rotation(self):
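        """Unit tests for random rotation transform"""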

        # Each invalid degrees argument must raise ValueError.
        with self.assertRaises(ValueError):
            transforms.RandomRotation(-0.7)
        with self.assertRaises(ValueError):
            transforms.RandomRotation([-0.7])
        with self.assertRaises(ValueError):
            transforms.RandomRotation([-0.7, 0, 0.7])

        t = transforms.RandomRotation(10)
        angle = t.get_params(t.degrees)
        self.assertTrue(-10 <= angle <= 10)

        t = transforms.RandomRotation((-10, 10))
        angle = t.get_params(t.degrees)
        self.assertTrue(-10 <= angle <= 10)

        # Checking if RandomRotation can be printed as string
        t.__repr__()

    def test_random_affine(self):
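        """Unit tests for random affine transform"""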

        # Every invalid parameter combination must fail at construction time.
        # Older torchvision releases validate some of these with plain asserts
        # rather than ValueError, so both exception types are accepted here.
        bad_params = [
            dict(degrees=-0.7),
            dict(degrees=[-0.7]),
            dict(degrees=[-0.7, 0, 0.7]),
            dict(degrees=[-90, 90], translate=2.0),
            dict(degrees=[-90, 90], translate=[-1.0, 1.0]),
            dict(degrees=[-90, 90], translate=[-1.0, 0.0, 1.0]),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[0.0]),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[-1.0, 1.0]),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[0.5, -0.5]),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[0.5, 3.0, -0.5]),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=-7),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10]),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 0, 10]),
            dict(degrees=[-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 0, 10, 0, 10]),
        ]
        for kwargs in bad_params:
            with self.assertRaises((ValueError, AssertionError)):
                transforms.RandomAffine(**kwargs)

        x = np.zeros((100, 100, 3), dtype=np.uint8)
        img = F.to_pil_image(x)

        t = transforms.RandomAffine(10, translate=[0.5, 0.3], scale=[0.7, 1.3], shear=[-10, 10, 20, 40])
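        # get_params should sample inside the configured ranges; translations are
        # returned in pixels, so their bounds scale with the image size.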
        for _ in range(100):
            angle, translations, scale, shear = t.get_params(t.degrees, t.translate, t.scale, t.shear,
                                                             img_size=img.size)
            self.assertTrue(-10 <= angle <= 10)
            self.assertTrue(-img.size[0] * 0.5 <= translations[0] <= img.size[0] * 0.5,
                            "{} vs {}".format(translations[0], img.size[0] * 0.5))
            self.assertTrue(-img.size[1] * 0.3 <= translations[1] <= img.size[1] * 0.3,
                            "{} vs {}".format(translations[1], img.size[1] * 0.3))
            self.assertTrue(0.7 <= scale <= 1.3)
            self.assertTrue(-10 <= shear[0] <= 10)
            self.assertTrue(20 <= shear[1] <= 40)

        # Checking if RandomAffine can be printed as string
        t.__repr__()

        t = transforms.RandomAffine(10, resample=Image.BILINEAR)
        self.assertIn("Image.BILINEAR", t.__repr__())

    def test_to_grayscale(self):
        """Unit tests for grayscale transform"""

        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')
        x_pil_2 = x_pil.convert('L')
        gray_np = np.array(x_pil_2)
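        # PIL's 'L' conversion uses the ITU-R 601-2 luma transform:
        #   L = R * 299/1000 + G * 587/1000 + B * 114/1000
        # gray_np is the reference output for all cases below.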

        # Test Set: Grayscale an image with desired number of output channels
        # Case 1: RGB -> 1 channel grayscale
        trans1 = transforms.Grayscale(num_output_channels=1)
        gray_pil_1 = trans1(x_pil)
        gray_np_1 = np.array(gray_pil_1)
        self.assertEqual(gray_pil_1.mode, 'L', 'mode should be L')
        self.assertEqual(gray_np_1.shape, tuple(x_shape[0:2]), 'should be 1 channel')
        np.testing.assert_equal(gray_np, gray_np_1)

        # Case 2: RGB -> 3 channel grayscale
        trans2 = transforms.Grayscale(num_output_channels=3)
        gray_pil_2 = trans2(x_pil)
        gray_np_2 = np.array(gray_pil_2)
        self.assertEqual(gray_pil_2.mode, 'RGB', 'mode should be RGB')
        self.assertEqual(gray_np_2.shape, tuple(x_shape), 'should be 3 channel')
        np.testing.assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
        np.testing.assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
        np.testing.assert_equal(gray_np, gray_np_2[:, :, 0])

        # Case 3: 1 channel grayscale -> 1 channel grayscale
        trans3 = transforms.Grayscale(num_output_channels=1)
        gray_pil_3 = trans3(x_pil_2)
        gray_np_3 = np.array(gray_pil_3)
        self.assertEqual(gray_pil_3.mode, 'L', 'mode should be L')
        self.assertEqual(gray_np_3.shape, tuple(x_shape[0:2]), 'should be 1 channel')
        np.testing.assert_equal(gray_np, gray_np_3)

        # Case 4: 1 channel grayscale -> 3 channel grayscale
        trans4 = transforms.Grayscale(num_output_channels=3)
        gray_pil_4 = trans4(x_pil_2)
        gray_np_4 = np.array(gray_pil_4)
        self.assertEqual(gray_pil_4.mode, 'RGB', 'mode should be RGB')
        self.assertEqual(gray_np_4.shape, tuple(x_shape), 'should be 3 channel')
        np.testing.assert_equal(gray_np_4[:, :, 0], gray_np_4[:, :, 1])
        np.testing.assert_equal(gray_np_4[:, :, 1], gray_np_4[:, :, 2])
        np.testing.assert_equal(gray_np, gray_np_4[:, :, 0])

        # Checking if Grayscale can be printed as string
        trans4.__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_grayscale(self):
        """Unit tests for random grayscale transform"""

        # Test Set 1: RGB -> 3 channel grayscale
        random_state = random.getstate()
        random.seed(42)
        x_shape = [2, 2, 3]
        x_np = np.random.randint(0, 256, x_shape, np.uint8)
        x_pil = Image.fromarray(x_np, mode='RGB')
        x_pil_2 = x_pil.convert('L')
        gray_np = np.array(x_pil_2)

        num_samples = 250
        num_gray = 0
        for _ in range(num_samples):
            gray_pil_2 = transforms.RandomGrayscale(p=0.5)(x_pil)
            gray_np_2 = np.array(gray_pil_2)
            if np.array_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) and \
                    np.array_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) and \
                    np.array_equal(gray_np, gray_np_2[:, :, 0]):
                num_gray = num_gray + 1

        # num_gray across 250 samples should look like 250 fair coin flips; the
        # two-sided binomial test rejects only if the observed rate is far from 0.5.
        p_value = stats.binom_test(num_gray, num_samples, p=0.5)
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        # Test Set 2: grayscale -> 1 channel grayscale
        random_state = random.getstate()
        random.seed(42)
        x_shape = [2, 2, 3]
        x_np = np.random.randint(0, 256, x_shape, np.uint8)
        x_pil = Image.fromarray(x_np, mode='RGB')
        x_pil_2 = x_pil.convert('L')
        gray_np = np.array(x_pil_2)

        num_samples = 250
        num_gray = 0
        for _ in range(num_samples):
            gray_pil_3 = transforms.RandomGrayscale(p=0.5)(x_pil_2)
            gray_np_3 = np.array(gray_pil_3)
            if np.array_equal(gray_np, gray_np_3):
                num_gray = num_gray + 1

        p_value = stats.binom_test(num_gray, num_samples, p=1.0)  # Note: grayscale is always unchanged
        random.setstate(random_state)
        self.assertGreater(p_value, 0.0001)

        # Test set 3: Explicit tests
        x_shape = [2, 2, 3]
        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
        x_pil = Image.fromarray(x_np, mode='RGB')
        x_pil_2 = x_pil.convert('L')
        gray_np = np.array(x_pil_2)

        # Case 3a: RGB -> 3 channel grayscale (grayscaled)
        trans2 = transforms.RandomGrayscale(p=1.0)
        gray_pil_2 = trans2(x_pil)
        gray_np_2 = np.array(gray_pil_2)
        self.assertEqual(gray_pil_2.mode, 'RGB', 'mode should be RGB')
        self.assertEqual(gray_np_2.shape, tuple(x_shape), 'should be 3 channel')
        np.testing.assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
        np.testing.assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
        np.testing.assert_equal(gray_np, gray_np_2[:, :, 0])

        # Case 3b: RGB -> 3 channel grayscale (unchanged)
        trans2 = transforms.RandomGrayscale(p=0.0)
        gray_pil_2 = trans2(x_pil)
        gray_np_2 = np.array(gray_pil_2)
        self.assertEqual(gray_pil_2.mode, 'RGB', 'mode should be RGB')
        self.assertEqual(gray_np_2.shape, tuple(x_shape), 'should be 3 channel')
        np.testing.assert_equal(x_np, gray_np_2)

        # Case 3c: 1 channel grayscale -> 1 channel grayscale (grayscaled)
        trans3 = transforms.RandomGrayscale(p=1.0)
        gray_pil_3 = trans3(x_pil_2)
        gray_np_3 = np.array(gray_pil_3)
        self.assertEqual(gray_pil_3.mode, 'L', 'mode should be L')
        self.assertEqual(gray_np_3.shape, tuple(x_shape[0:2]), 'should be 1 channel')
        np.testing.assert_equal(gray_np, gray_np_3)

        # Case 3d: 1 channel grayscale -> 1 channel grayscale (unchanged)
        trans3 = transforms.RandomGrayscale(p=0.0)
        gray_pil_3 = trans3(x_pil_2)
        gray_np_3 = np.array(gray_pil_3)
        self.assertEqual(gray_pil_3.mode, 'L', 'mode should be L')
        self.assertEqual(gray_np_3.shape, tuple(x_shape[0:2]), 'should be 1 channel')
        np.testing.assert_equal(gray_np, gray_np_3)

        # Checking if RandomGrayscale can be printed as string
        trans3.__repr__()

    def test_random_erasing(self):
        """Unit tests for random erasing transform"""

        img = torch.rand([3, 60, 60])

        # Test Set 1: Erasing with a fixed scalar value
        img_re = transforms.RandomErasing(value=0.2)
        i, j, h, w, v = img_re.get_params(img, scale=img_re.scale, ratio=img_re.ratio, value=img_re.value)
        img_output = F.erase(img, i, j, h, w, v)
        self.assertEqual(img_output.size(0), 3)

        # Test Set 2: Check if the unerased region is preserved
        orig_unerased = img.clone()
        orig_unerased[:, i:i + h, j:j + w] = 0
        output_unerased = img_output.clone()
        output_unerased[:, i:i + h, j:j + w] = 0
        self.assertTrue(torch.equal(orig_unerased, output_unerased))

        # Test Set 3: Erasing with random value
        img_re = transforms.RandomErasing(value='random')(img)
        self.assertEqual(img_re.size(0), 3)

        # Test Set 4: Erasing with tuple value
        img_re = transforms.RandomErasing(value=(0.2, 0.2, 0.2))(img)
        self.assertEqual(img_re.size(0), 3)

        # Test Set 5: Testing the inplace behaviour
        img_re = transforms.RandomErasing(value=0.2, inplace=True)(img)
        self.assertTrue(torch.equal(img_re, img))

        # Test Set 6: Checking when no erased region is selected
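        # On a 1-pixel-wide image no candidate block satisfies both the scale and
        # the (0.1, 0.2) aspect-ratio constraints, so the input should come back
        # unchanged.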
        img = torch.rand([3, 300, 1])
        img_re = transforms.RandomErasing(ratio=(0.1, 0.2), value='random')(img)
        self.assertTrue(torch.equal(img_re, img))


if __name__ == '__main__':
    unittest.main()