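"""Tests for the tensor backends of torchvision.transforms.functional.

Each test compares the tensor implementation against the PIL reference,
checks that the function remains torch.jit scriptable, and verifies that
batched inputs give the same results as per-sample calls. CUDATester at
the bottom re-runs the whole suite on a CUDA device.
"""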
import os
import unittest
import colorsys
import math

import numpy as np

import torch
import torchvision.transforms.functional_tensor as F_t
import torchvision.transforms.functional_pil as F_pil
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode

from common_utils import TransformsTester

from typing import Dict, List, Sequence, Tuple


NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC


class Tester(TransformsTester):

    def setUp(self):
        self.device = "cpu"

    def _test_fn_on_batch(self, batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
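        """Check that ``fn`` applied to a batch matches ``fn`` applied to each
        sample, and that the scripted ``fn`` agrees with the eager result within
        ``scripted_fn_atol`` (a negative atol skips the scripted check)."""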
        transformed_batch = fn(batch_tensors, **fn_kwargs)
        for i in range(len(batch_tensors)):
            img_tensor = batch_tensors[i, ...]
            transformed_img = fn(img_tensor, **fn_kwargs)
            self.assertTrue(transformed_img.equal(transformed_batch[i, ...]))

        if scripted_fn_atol >= 0:
            scripted_fn = torch.jit.script(fn)
            # scriptable function test
            s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
            self.assertTrue(transformed_batch.allclose(s_transformed_batch, atol=scripted_fn_atol))

    def test_assert_image_tensor(self):
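        """A 1-d tensor is not an image, so every functional_tensor op should reject it."""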
        shape = (100,)
        tensor = torch.rand(*shape, dtype=torch.float, device=self.device)

        list_of_methods = [(F_t._get_image_size, (tensor, )), (F_t.vflip, (tensor, )),
                           (F_t.hflip, (tensor, )), (F_t.crop, (tensor, 1, 2, 4, 5)),
                           (F_t.adjust_brightness, (tensor, 0.)), (F_t.adjust_contrast, (tensor, 1.)),
                           (F_t.adjust_hue, (tensor, -0.5)), (F_t.adjust_saturation, (tensor, 2.)),
                           (F_t.center_crop, (tensor, [10, 11])), (F_t.five_crop, (tensor, [10, 11])),
                           (F_t.ten_crop, (tensor, [10, 11])), (F_t.pad, (tensor, [2, ], 2, "constant")),
                           (F_t.resize, (tensor, [10, 11])), (F_t.perspective, (tensor, [0.2, ])),
                           (F_t.gaussian_blur, (tensor, (2, 2), (0.7, 0.5))),
                           (F_t.invert, (tensor, )), (F_t.posterize, (tensor, 0)),
                           (F_t.solarize, (tensor, 0.3)), (F_t.adjust_sharpness, (tensor, 0.3)),
                           (F_t.autocontrast, (tensor, )), (F_t.equalize, (tensor, ))]

        for func, args in list_of_methods:
            with self.assertRaises(Exception) as context:
                func(*args)

            self.assertTrue('Tensor is not a torch image.' in str(context.exception))

    def test_vflip(self):
        script_vflip = torch.jit.script(F.vflip)

        img_tensor, pil_img = self._create_data(16, 18, device=self.device)
        vflipped_img = F.vflip(img_tensor)
        vflipped_pil_img = F.vflip(pil_img)
        self.compareTensorToPIL(vflipped_img, vflipped_pil_img)

        # scriptable function test
        vflipped_img_script = script_vflip(img_tensor)
        self.assertTrue(vflipped_img.equal(vflipped_img_script))

        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
        self._test_fn_on_batch(batch_tensors, F.vflip)

    def test_hflip(self):
        script_hflip = torch.jit.script(F.hflip)

        img_tensor, pil_img = self._create_data(16, 18, device=self.device)
        hflipped_img = F.hflip(img_tensor)
        hflipped_pil_img = F.hflip(pil_img)
        self.compareTensorToPIL(hflipped_img, hflipped_pil_img)

        # scriptable function test
        hflipped_img_script = script_hflip(img_tensor)
        self.assertTrue(hflipped_img.equal(hflipped_img_script))

        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
        self._test_fn_on_batch(batch_tensors, F.hflip)

    def test_crop(self):
        script_crop = torch.jit.script(F.crop)

        img_tensor, pil_img = self._create_data(16, 18, device=self.device)

        test_configs = [
            (1, 2, 4, 5),   # crop inside top-left corner
            (2, 12, 3, 4),  # crop inside top-right corner
            (8, 3, 5, 6),   # crop inside bottom-left corner
            (8, 11, 4, 3),  # crop inside bottom-right corner
        ]

        for top, left, height, width in test_configs:
            pil_img_cropped = F.crop(pil_img, top, left, height, width)

            img_tensor_cropped = F.crop(img_tensor, top, left, height, width)
            self.compareTensorToPIL(img_tensor_cropped, pil_img_cropped)

            img_tensor_cropped = script_crop(img_tensor, top, left, height, width)
            self.compareTensorToPIL(img_tensor_cropped, pil_img_cropped)

            batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
            self._test_fn_on_batch(batch_tensors, F.crop, top=top, left=left, height=height, width=width)

    def test_hsv2rgb(self):
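        """_hsv2rgb should agree with colorsys.hsv_to_rgb pixel-by-pixel."""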
        scripted_fn = torch.jit.script(F_t._hsv2rgb)
        shape = (3, 100, 150)
        for _ in range(10):
            hsv_img = torch.rand(*shape, dtype=torch.float, device=self.device)
            rgb_img = F_t._hsv2rgb(hsv_img)
            ft_img = rgb_img.permute(1, 2, 0).flatten(0, 1)

            h, s, v = hsv_img.unbind(0)
            h = h.flatten().cpu().numpy()
            s = s.flatten().cpu().numpy()
            v = v.flatten().cpu().numpy()

            rgb = []
            for h1, s1, v1 in zip(h, s, v):
                rgb.append(colorsys.hsv_to_rgb(h1, s1, v1))
            colorsys_img = torch.tensor(rgb, dtype=torch.float32, device=self.device)
            max_diff = (ft_img - colorsys_img).abs().max()
            self.assertLess(max_diff, 1e-5)

            s_rgb_img = scripted_fn(hsv_img)
            self.assertTrue(rgb_img.allclose(s_rgb_img))

        batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float()
        self._test_fn_on_batch(batch_tensors, F_t._hsv2rgb)

    def test_rgb2hsv(self):
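        """_rgb2hsv should agree with colorsys.rgb_to_hsv pixel-by-pixel."""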
        scripted_fn = torch.jit.script(F_t._rgb2hsv)
        shape = (3, 150, 100)
        for _ in range(10):
            rgb_img = torch.rand(*shape, dtype=torch.float, device=self.device)
            hsv_img = F_t._rgb2hsv(rgb_img)
            ft_hsv_img = hsv_img.permute(1, 2, 0).flatten(0, 1)

            r, g, b = rgb_img.unbind(dim=-3)
            r = r.flatten().cpu().numpy()
            g = g.flatten().cpu().numpy()
            b = b.flatten().cpu().numpy()

            hsv = []
            for r1, g1, b1 in zip(r, g, b):
                hsv.append(colorsys.rgb_to_hsv(r1, g1, b1))

            colorsys_img = torch.tensor(hsv, dtype=torch.float32, device=self.device)

            ft_hsv_img_h, ft_hsv_img_sv = torch.split(ft_hsv_img, [1, 2], dim=1)
            colorsys_img_h, colorsys_img_sv = torch.split(colorsys_img, [1, 2], dim=1)

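            # Compare hue via sin(2*pi*h) so that values on either side of the
            # 0/1 wrap-around are treated as equal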
            max_diff_h = ((colorsys_img_h * 2 * math.pi).sin() - (ft_hsv_img_h * 2 * math.pi).sin()).abs().max()
            max_diff_sv = (colorsys_img_sv - ft_hsv_img_sv).abs().max()
            max_diff = max(max_diff_h, max_diff_sv)
            self.assertLess(max_diff, 1e-5)

            s_hsv_img = scripted_fn(rgb_img)
            self.assertTrue(hsv_img.allclose(s_hsv_img, atol=1e-7))

        batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float()
        self._test_fn_on_batch(batch_tensors, F_t._rgb2hsv)

    def test_rgb_to_grayscale(self):
        script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)

        img_tensor, pil_img = self._create_data(32, 34, device=self.device)

        for num_output_channels in (3, 1):
            gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
            gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)

            self.approxEqualTensorToPIL(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")

            s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
            self.assertTrue(s_gray_tensor.equal(gray_tensor))

            batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
            self._test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)

    def test_center_crop(self):
        script_center_crop = torch.jit.script(F.center_crop)

        img_tensor, pil_img = self._create_data(32, 34, device=self.device)

        cropped_pil_image = F.center_crop(pil_img, [10, 11])

        cropped_tensor = F.center_crop(img_tensor, [10, 11])
        self.compareTensorToPIL(cropped_tensor, cropped_pil_image)

        cropped_tensor = script_center_crop(img_tensor, [10, 11])
        self.compareTensorToPIL(cropped_tensor, cropped_pil_image)

        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
        self._test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])

    def test_five_crop(self):
        script_five_crop = torch.jit.script(F.five_crop)

        img_tensor, pil_img = self._create_data(32, 34, device=self.device)

        cropped_pil_images = F.five_crop(pil_img, [10, 11])

        cropped_tensors = F.five_crop(img_tensor, [10, 11])
        for i in range(5):
            self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])

        cropped_tensors = script_five_crop(img_tensor, [10, 11])
        for i in range(5):
            self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])

        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
        tuple_transformed_batches = F.five_crop(batch_tensors, [10, 11])
        for i in range(len(batch_tensors)):
            img_tensor = batch_tensors[i, ...]
            tuple_transformed_imgs = F.five_crop(img_tensor, [10, 11])
            self.assertEqual(len(tuple_transformed_imgs), len(tuple_transformed_batches))

            for j in range(len(tuple_transformed_imgs)):
                true_transformed_img = tuple_transformed_imgs[j]
                transformed_img = tuple_transformed_batches[j][i, ...]
                self.assertTrue(true_transformed_img.equal(transformed_img))

        # scriptable function test
        s_tuple_transformed_batches = script_five_crop(batch_tensors, [10, 11])
        for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
            self.assertTrue(transformed_batch.equal(s_transformed_batch))

    def test_ten_crop(self):
        script_ten_crop = torch.jit.script(F.ten_crop)

        img_tensor, pil_img = self._create_data(32, 34, device=self.device)

        cropped_pil_images = F.ten_crop(pil_img, [10, 11])

        cropped_tensors = F.ten_crop(img_tensor, [10, 11])
        for i in range(10):
            self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])

        cropped_tensors = script_ten_crop(img_tensor, [10, 11])
        for i in range(10):
            self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])

        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
        tuple_transformed_batches = F.ten_crop(batch_tensors, [10, 11])
        for i in range(len(batch_tensors)):
            img_tensor = batch_tensors[i, ...]
            tuple_transformed_imgs = F.ten_crop(img_tensor, [10, 11])
            self.assertEqual(len(tuple_transformed_imgs), len(tuple_transformed_batches))

            for j in range(len(tuple_transformed_imgs)):
                true_transformed_img = tuple_transformed_imgs[j]
                transformed_img = tuple_transformed_batches[j][i, ...]
                self.assertTrue(true_transformed_img.equal(transformed_img))

        # scriptable function test
        s_tuple_transformed_batches = script_ten_crop(batch_tensors, [10, 11])
        for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
            self.assertTrue(transformed_batch.equal(s_transformed_batch))

    def test_pad(self):
        script_fn = torch.jit.script(F.pad)
        tensor, pil_img = self._create_data(7, 8, device=self.device)
        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)

        for dt in [None, torch.float32, torch.float64, torch.float16]:

            if dt == torch.float16 and torch.device(self.device).type == "cpu":
                # skip float16 on CPU case
                continue

            if dt is not None:
                # Trivially cast the uint8 data to float so that all dtype cases are tested
                tensor = tensor.to(dt)
                batch_tensors = batch_tensors.to(dt)

            for pad in [2, [3, ], [0, 3], (3, 3), [4, 2, 4, 3]]:
                configs = [
                    {"padding_mode": "constant", "fill": 0},
                    {"padding_mode": "constant", "fill": 10},
                    {"padding_mode": "constant", "fill": 20},
                    {"padding_mode": "edge"},
                    {"padding_mode": "reflect"},
                    {"padding_mode": "symmetric"},
                ]
                for kwargs in configs:
                    pad_tensor = F_t.pad(tensor, pad, **kwargs)
                    pad_pil_img = F_pil.pad(pil_img, pad, **kwargs)

                    pad_tensor_8b = pad_tensor
                    # we need to cast to uint8 to compare with PIL image
                    if pad_tensor_8b.dtype != torch.uint8:
                        pad_tensor_8b = pad_tensor_8b.to(torch.uint8)

                    self.compareTensorToPIL(pad_tensor_8b, pad_pil_img, msg="{}, {}".format(pad, kwargs))

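                    # the scripted function expects the padding as a list of
                    # ints, so wrap a bare int before calling it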
                    if isinstance(pad, int):
                        script_pad = [pad, ]
                    else:
                        script_pad = pad
                    pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
                    self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, kwargs))

                    self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)

    def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max",
                        dts=(None, torch.float32, torch.float64)):
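        """Driver for the color-adjustment tests: compares ``fn_t`` against the
        PIL reference ``fn_pil`` for every dtype in ``dts`` and every kwargs
        dict in ``configs``, and also exercises the scripted and batched paths."""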
        script_fn = torch.jit.script(fn)
        torch.manual_seed(15)
        tensor, pil_img = self._create_data(26, 34, device=self.device)
        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)

        for dt in dts:

            if dt is not None:
                tensor = F.convert_image_dtype(tensor, dt)
                batch_tensors = F.convert_image_dtype(batch_tensors, dt)

            for config in configs:
                adjusted_tensor = fn_t(tensor, **config)
                adjusted_pil = fn_pil(pil_img, **config)
                scripted_result = script_fn(tensor, **config)
                msg = "{}, {}".format(dt, config)
                self.assertEqual(adjusted_tensor.dtype, scripted_result.dtype, msg=msg)
                self.assertEqual(adjusted_tensor.size()[1:], adjusted_pil.size[::-1], msg=msg)

                rgb_tensor = adjusted_tensor

                if adjusted_tensor.dtype != torch.uint8:
                    rgb_tensor = F.convert_image_dtype(adjusted_tensor, torch.uint8)

                # Check that the max difference does not exceed tol in the [0, 255] range
                # Exact matching is not possible due to the incompatibility between convert_image_dtype and PIL results
                self.approxEqualTensorToPIL(rgb_tensor.float(), adjusted_pil, tol=tol, msg=msg, agg_method=agg_method)

                atol = 1e-6
                if adjusted_tensor.dtype == torch.uint8 and "cuda" in torch.device(self.device).type:
                    atol = 1.0
                self.assertTrue(adjusted_tensor.allclose(scripted_result, atol=atol), msg=msg)

                self._test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=atol, **config)

    def test_adjust_brightness(self):
        self._test_adjust_fn(
            F.adjust_brightness,
            F_pil.adjust_brightness,
            F_t.adjust_brightness,
            [{"brightness_factor": f} for f in [0.1, 0.5, 1.0, 1.34, 2.5]]
        )

    def test_adjust_contrast(self):
        self._test_adjust_fn(
            F.adjust_contrast,
            F_pil.adjust_contrast,
            F_t.adjust_contrast,
            [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]
        )

    def test_adjust_saturation(self):
        self._test_adjust_fn(
            F.adjust_saturation,
            F_pil.adjust_saturation,
            F_t.adjust_saturation,
            [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]]
        )

    def test_adjust_hue(self):
        self._test_adjust_fn(
            F.adjust_hue,
            F_pil.adjust_hue,
            F_t.adjust_hue,
            [{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]],
            tol=16.1,
            agg_method="max"
        )

    def test_adjust_gamma(self):
        self._test_adjust_fn(
            F.adjust_gamma,
            F_pil.adjust_gamma,
            F_t.adjust_gamma,
            [{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])]
        )

    def test_resize(self):
        script_fn = torch.jit.script(F.resize)
        tensor, pil_img = self._create_data(26, 36, device=self.device)
        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)

        for dt in [None, torch.float32, torch.float64, torch.float16]:

            if dt == torch.float16 and torch.device(self.device).type == "cpu":
                # skip float16 on CPU case
                continue

            if dt is not None:
                # Trivially cast the uint8 data to float so that all dtype cases are tested
                tensor = tensor.to(dt)
                batch_tensors = batch_tensors.to(dt)

            for size in [32, 26, [32, ], [32, 32], (32, 32), [26, 35]]:
                for max_size in (None, 33, 40, 1000):
                    if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
                        continue  # unsupported, see assertRaises below
                    for interpolation in [BILINEAR, BICUBIC, NEAREST]:
                        resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size)
                        resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size)

                        self.assertEqual(
                            resized_tensor.size()[1:], resized_pil_img.size[::-1],
                            msg="{}, {}".format(size, interpolation)
                        )

                        if interpolation not in [NEAREST, ]:
                            # We cannot check values if mode = NEAREST, as results are different
                            # E.g. resized_tensor  = [[a, a, b, c, d, d, e, ...]]
                            # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
                            resized_tensor_f = resized_tensor
                            # we need to cast the uint8 tensor to float to compare with the PIL image
                            if resized_tensor_f.dtype == torch.uint8:
                                resized_tensor_f = resized_tensor_f.to(torch.float)

                            # Pay attention to high tolerance for MAE
                            self.approxEqualTensorToPIL(
                                resized_tensor_f, resized_pil_img, tol=8.0, msg="{}, {}".format(size, interpolation)
                            )

                        if isinstance(size, int):
                            script_size = [size, ]
                        else:
                            script_size = size

                        resize_result = script_fn(tensor, size=script_size, interpolation=interpolation,
                                                  max_size=max_size)
                        self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))

                        self._test_fn_on_batch(
                            batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size
                        )

        # assert changed type warning
        with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
            res1 = F.resize(tensor, size=32, interpolation=2)
            res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
            self.assertTrue(res1.equal(res2))

        for img in (tensor, pil_img):
            exp_msg = "max_size should only be passed if size specifies the length of the smaller edge"
            with self.assertRaisesRegex(ValueError, exp_msg):
                F.resize(img, size=(32, 34), max_size=35)
            with self.assertRaisesRegex(ValueError, "max_size = 32 must be strictly greater"):
                F.resize(img, size=32, max_size=32)

    def test_resized_crop(self):
        # test values of F.resized_crop in several cases:
        # 1) resize to the same size, crop to the same size => should be identity
        tensor, _ = self._create_data(26, 36, device=self.device)

        for mode in [NEAREST, BILINEAR, BICUBIC]:
            out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode)
            self.assertTrue(tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))

        # 2) resize by half and crop a TL corner
        tensor, _ = self._create_data(26, 36, device=self.device)
        out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=NEAREST)
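        # with NEAREST interpolation, downscaling the 20x30 crop by exactly 2x
        # picks every second pixel, i.e. a stride-2 slice of the input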
        expected_out_tensor = tensor[:, :20:2, :30:2]
        self.assertTrue(
            expected_out_tensor.equal(out_tensor),
            msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10])
        )

        batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
        self._test_fn_on_batch(
            batch_tensors, F.resized_crop, top=1, left=2, height=20, width=30, size=[10, 15], interpolation=NEAREST
        )

    def _test_affine_identity_map(self, tensor, scripted_affine):
        # 1) identity map
        out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)

        self.assertTrue(
            tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
        )
        out_tensor = scripted_affine(
            tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
        )
        self.assertTrue(
            tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
        )

    def _test_affine_square_rotations(self, tensor, pil_img, scripted_affine):
        # 2) Test rotation
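        # right-angle rotations have exact torch.rot90 references below; the
        # remaining angles are only compared against PIL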
        test_configs = [
            (90, torch.rot90(tensor, k=1, dims=(-1, -2))),
            (45, None),
            (30, None),
            (-30, None),
            (-45, None),
            (-90, torch.rot90(tensor, k=-1, dims=(-1, -2))),
            (180, torch.rot90(tensor, k=2, dims=(-1, -2))),
        ]
        for a, true_tensor in test_configs:
            out_pil_img = F.affine(
                pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
            )
            out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))).to(self.device)

            for fn in [F.affine, scripted_affine]:
                out_tensor = fn(
                    tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
                )
                if true_tensor is not None:
                    self.assertTrue(
                        true_tensor.equal(out_tensor),
                        msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
                    )

                if out_tensor.dtype != torch.uint8:
                    out_tensor = out_tensor.to(torch.uint8)

                num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                # Tolerance : less than 6% of different pixels
                self.assertLess(
                    ratio_diff_pixels,
                    0.06,
                    msg="{}\n{} vs \n{}".format(
                        ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
                    )
                )

    def _test_affine_rect_rotations(self, tensor, pil_img, scripted_affine):
        test_configs = [
            90, 45, 15, -30, -60, -120
        ]
        for a in test_configs:

            out_pil_img = F.affine(
                pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
            )
            out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

            for fn in [F.affine, scripted_affine]:
                out_tensor = fn(
                    tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
                ).cpu()

                if out_tensor.dtype != torch.uint8:
                    out_tensor = out_tensor.to(torch.uint8)

                num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                # Tolerance : less than 3% of different pixels
                self.assertLess(
                    ratio_diff_pixels,
                    0.03,
                    msg="{}: {}\n{} vs \n{}".format(
                        a, ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
                    )
                )

    def _test_affine_translations(self, tensor, pil_img, scripted_affine):
        # 3) Test translation
        test_configs = [
            [10, 12], (-12, -13)
        ]
        for t in test_configs:

            out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)

            for fn in [F.affine, scripted_affine]:
                out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)

                if out_tensor.dtype != torch.uint8:
                    out_tensor = out_tensor.to(torch.uint8)

                self.compareTensorToPIL(out_tensor, out_pil_img)

    def _test_affine_all_ops(self, tensor, pil_img, scripted_affine):
        # 4) Test rotation + translation + scale + shear
        test_configs = [
            (45.5, [5, 6], 1.0, [0.0, 0.0], None),
            (33, (5, -4), 1.0, [0.0, 0.0], [0, 0, 0]),
            (45, [-5, 4], 1.2, [0.0, 0.0], (1, 2, 3)),
            (33, (-4, -8), 2.0, [0.0, 0.0], [255, 255, 255]),
            (85, (10, -10), 0.7, [0.0, 0.0], [1, ]),
            (0, [0, 0], 1.0, [35.0, ], (2.0, )),
            (-25, [0, 0], 1.2, [0.0, 15.0], None),
            (-45, [-10, 0], 0.7, [2.0, 5.0], None),
            (-45, [-10, -10], 1.2, [4.0, 5.0], None),
            (-90, [0, 0], 1.0, [0.0, 0.0], None),
        ]
        for r in [NEAREST, ]:
            for a, t, s, sh, f in test_configs:
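                # a single-element fill sequence is passed to PIL as a plain int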
                f_pil = int(f[0]) if f is not None and len(f) == 1 else f
                out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, interpolation=r, fill=f_pil)
                out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

                for fn in [F.affine, scripted_affine]:
                    out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, interpolation=r, fill=f).cpu()

                    if out_tensor.dtype != torch.uint8:
                        out_tensor = out_tensor.to(torch.uint8)

                    num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                    ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                    # Tolerance : less than 5% (cpu), 6% (cuda) of different pixels
                    tol = 0.06 if self.device == "cuda" else 0.05
                    self.assertLess(
                        ratio_diff_pixels,
                        tol,
                        msg="{}: {}\n{} vs \n{}".format(
                            (r, a, t, s, sh, f), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
                        )
                    )

    def test_affine(self):
        # Tests on square and rectangular images
        scripted_affine = torch.jit.script(F.affine)

        data = [self._create_data(26, 26, device=self.device), self._create_data(32, 26, device=self.device)]
        for tensor, pil_img in data:

            for dt in [None, torch.float32, torch.float64, torch.float16]:

                if dt == torch.float16 and torch.device(self.device).type == "cpu":
                    # skip float16 on CPU case
                    continue

                if dt is not None:
                    tensor = tensor.to(dtype=dt)

                self._test_affine_identity_map(tensor, scripted_affine)
                if pil_img.size[0] == pil_img.size[1]:
                    self._test_affine_square_rotations(tensor, pil_img, scripted_affine)
                else:
                    self._test_affine_rect_rotations(tensor, pil_img, scripted_affine)
                self._test_affine_translations(tensor, pil_img, scripted_affine)
                self._test_affine_all_ops(tensor, pil_img, scripted_affine)

                batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
                if dt is not None:
                    batch_tensors = batch_tensors.to(dtype=dt)

                self._test_fn_on_batch(
                    batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0]
                )

        tensor, pil_img = data[0]
        # assert deprecation warning and non-BC
        with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
            res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=2)
            res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
            self.assertTrue(res1.equal(res2))

        # assert changed type warning
        with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
            res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=2)
            res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
            self.assertTrue(res1.equal(res2))

        with self.assertWarnsRegex(UserWarning, r"Argument fillcolor is deprecated and will be removed"):
            res1 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fillcolor=10)
            res2 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fill=10)
            self.assertEqual(res1, res2)

    def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
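        """Sweep angles, expand flags, centers and fill values, comparing
        F.rotate (eager and scripted) against the PIL result."""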
        img_size = pil_img.size
        dt = tensor.dtype
        for r in [NEAREST, ]:
            for a in range(-180, 180, 17):
                for e in [True, False]:
                    for c in centers:
                        for f in [None, [0, 0, 0], (1, 2, 3), [255, 255, 255], [1, ], (2.0, )]:
                            f_pil = int(f[0]) if f is not None and len(f) == 1 else f
                            out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c, fill=f_pil)
                            out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
                            for fn in [F.rotate, scripted_rotate]:
                                out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c, fill=f).cpu()

                                if out_tensor.dtype != torch.uint8:
                                    out_tensor = out_tensor.to(torch.uint8)

                                self.assertEqual(
                                    out_tensor.shape,
                                    out_pil_tensor.shape,
                                    msg="{}: {} vs {}".format(
                                        (img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
                                    ))

                                num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                                ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                                # Tolerance : less than 3% of different pixels
                                self.assertLess(
                                    ratio_diff_pixels,
                                    0.03,
                                    msg="{}: {}\n{} vs \n{}".format(
                                        (img_size, r, dt, a, e, c, f),
                                        ratio_diff_pixels,
                                        out_tensor[0, :7, :7],
                                        out_pil_tensor[0, :7, :7]
                                    )
                                )

    def test_rotate(self):
        # Tests on square and rectangular images
        scripted_rotate = torch.jit.script(F.rotate)

        data = [self._create_data(26, 26, device=self.device), self._create_data(32, 26, device=self.device)]
        for tensor, pil_img in data:

            img_size = pil_img.size
            centers = [
                None,
                (int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
                [int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
            ]

            for dt in [None, torch.float32, torch.float64, torch.float16]:

                if dt == torch.float16 and torch.device(self.device).type == "cpu":
                    # skip float16 on CPU case
                    continue

                if dt is not None:
                    tensor = tensor.to(dtype=dt)

                self._test_rotate_all_options(tensor, pil_img, scripted_rotate, centers)

                batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
                if dt is not None:
                    batch_tensors = batch_tensors.to(dtype=dt)

                center = (20, 22)
                self._test_fn_on_batch(
                    batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center
                )
        tensor, pil_img = data[0]
        # assert deprecation warning and non-BC
        with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
            res1 = F.rotate(tensor, 45, resample=2)
            res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
            self.assertTrue(res1.equal(res2))

        # assert changed type warning
        with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
            res1 = F.rotate(tensor, 45, interpolation=2)
            res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
            self.assertTrue(res1.equal(res2))

    def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs):
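        """Compare F.perspective (eager and scripted) against PIL for every
        start/end-point config and fill value."""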
        dt = tensor.dtype
        for f in [None, [0, 0, 0], [1, 2, 3], [255, 255, 255], [1, ], (2.0, )]:
            for r in [NEAREST, ]:
                for spoints, epoints in test_configs:
                    f_pil = int(f[0]) if f is not None and len(f) == 1 else f
                    out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r,
                                                fill=f_pil)
                    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

                    for fn in [F.perspective, scripted_transform]:
                        out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r, fill=f).cpu()

                        if out_tensor.dtype != torch.uint8:
                            out_tensor = out_tensor.to(torch.uint8)

                        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                        ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                        # Tolerance : less than 5% of different pixels
                        self.assertLess(
                            ratio_diff_pixels,
                            0.05,
                            msg="{}: {}\n{} vs \n{}".format(
                                (f, r, dt, spoints, epoints),
                                ratio_diff_pixels,
                                out_tensor[0, :7, :7],
                                out_pil_tensor[0, :7, :7]
                            )
                        )

    def test_perspective(self):

        from torchvision.transforms import RandomPerspective

        data = [self._create_data(26, 34, device=self.device), self._create_data(26, 26, device=self.device)]
        scripted_transform = torch.jit.script(F.perspective)

        for tensor, pil_img in data:

            test_configs = [
                [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
                [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
                [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
            ]
            n = 10
            test_configs += [
                RandomPerspective.get_params(pil_img.size[0], pil_img.size[1], i / n) for i in range(n)
            ]

            for dt in [None, torch.float32, torch.float64, torch.float16]:

                if dt == torch.float16 and torch.device(self.device).type == "cpu":
                    # skip float16 on CPU case
                    continue

                if dt is not None:
                    tensor = tensor.to(dtype=dt)

                self._test_perspective(tensor, pil_img, scripted_transform, test_configs)

                batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
                if dt is not None:
                    batch_tensors = batch_tensors.to(dtype=dt)

                # Ignore the equivalence between scripted and regular function on float16 cuda. The pixels at
                # the border may be entirely different due to small rounding errors.
                scripted_fn_atol = -1 if (dt == torch.float16 and self.device == "cuda") else 1e-8

                for spoints, epoints in test_configs:
                    self._test_fn_on_batch(
                        batch_tensors, F.perspective, scripted_fn_atol=scripted_fn_atol,
                        startpoints=spoints, endpoints=epoints, interpolation=NEAREST
                    )

        # assert changed type warning
        spoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
        epoints = [[3, 2], [32, 3], [30, 24], [2, 25]]
        with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
            res1 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=2)
            res2 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=BILINEAR)
            self.assertTrue(res1.equal(res2))

    def test_gaussian_blur(self):
        small_image_tensor = torch.from_numpy(
            np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
        ).permute(2, 0, 1).to(self.device)

        large_image_tensor = torch.from_numpy(
            np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))
        ).to(self.device)

        scripted_transform = torch.jit.script(F.gaussian_blur)

        # true_cv2_results = {
        #     # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
        #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8)
        #     "3_3_0.8": ...
        #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5)
        #     "3_3_0.5": ...
        #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8)
        #     "3_5_0.8": ...
        #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5)
        #     "3_5_0.5": ...
        #     # np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
        #     # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7)
        #     "23_23_1.7": ...
        # }
        p = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'gaussian_blur_opencv_results.pt')
        true_cv2_results = torch.load(p)

        for tensor in [small_image_tensor, large_image_tensor]:

            for dt in [None, torch.float32, torch.float64, torch.float16]:
                if dt == torch.float16 and torch.device(self.device).type == "cpu":
                    # skip float16 on CPU case
                    continue

                if dt is not None:
                    tensor = tensor.to(dtype=dt)

                for ksize in [(3, 3), [3, 5], (23, 23)]:
                    for sigma in [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)]:

                        _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize
                        _sigma = sigma[0] if sigma is not None else None
                        shape = tensor.shape
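                        # keys in the OpenCV reference file encode H_W_C__kx_ky_sigma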
                        gt_key = "{}_{}_{}__{}_{}_{}".format(
                            shape[-2], shape[-1], shape[-3],
                            _ksize[0], _ksize[1], _sigma
                        )
                        if gt_key not in true_cv2_results:
                            continue

                        true_out = torch.tensor(
                            true_cv2_results[gt_key]
                        ).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)

                        for fn in [F.gaussian_blur, scripted_transform]:
                            out = fn(tensor, kernel_size=ksize, sigma=sigma)
                            self.assertEqual(true_out.shape, out.shape, msg="{}, {}".format(ksize, sigma))
                            self.assertLessEqual(
                                torch.max(true_out.float() - out.float()),
                                1.0,
                                msg="{}, {}".format(ksize, sigma)
                            )

    def test_invert(self):
        self._test_adjust_fn(
            F.invert,
            F_pil.invert,
            F_t.invert,
            [{}],
            tol=1.0,
            agg_method="max"
        )

    def test_posterize(self):
        self._test_adjust_fn(
            F.posterize,
            F_pil.posterize,
            F_t.posterize,
            [{"bits": bits} for bits in range(0, 8)],
            tol=1.0,
            agg_method="max",
            dts=(None,)
        )

    def test_solarize(self):
        self._test_adjust_fn(
            F.solarize,
            F_pil.solarize,
            F_t.solarize,
            [{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]],
            tol=1.0,
            agg_method="max",
            dts=(None,)
        )
        self._test_adjust_fn(
            F.solarize,
            lambda img, threshold: F_pil.solarize(img, 255 * threshold),
            F_t.solarize,
            [{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]],
            tol=1.0,
            agg_method="max",
            dts=(torch.float32, torch.float64)
        )

    def test_adjust_sharpness(self):
        self._test_adjust_fn(
            F.adjust_sharpness,
            F_pil.adjust_sharpness,
            F_t.adjust_sharpness,
            [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]
        )

    def test_autocontrast(self):
        self._test_adjust_fn(
            F.autocontrast,
            F_pil.autocontrast,
            F_t.autocontrast,
            [{}],
            tol=1.0,
            agg_method="max"
        )

    def test_equalize(self):
        torch.set_deterministic(False)
        self._test_adjust_fn(
            F.equalize,
            F_pil.equalize,
            F_t.equalize,
            [{}],
            tol=1.0,
            agg_method="max",
            dts=(None,)
        )


@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
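    """Re-runs every Tester case on a CUDA device."""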

    def setUp(self):
        self.device = "cuda"

    def test_scale_channel(self):
        """Make sure that _scale_channel gives the same results on CPU and GPU as
        histc or bincount are used depending on the device.
        """
        # TODO: when # https://github.com/pytorch/pytorch/issues/53194 is fixed,
        # only use bincount and remove that test.
        size = (1_000,)
        img_chan = torch.randint(0, 256, size=size).to('cpu')
        scaled_cpu = F_t._scale_channel(img_chan)
        scaled_cuda = F_t._scale_channel(img_chan.to('cuda'))
        self.assertTrue(scaled_cpu.equal(scaled_cuda.to('cpu')))


if __name__ == '__main__':
    unittest.main()