import os
import unittest
import colorsys
import math

import numpy as np

import torch
import torchvision.transforms.functional_tensor as F_t
import torchvision.transforms.functional_pil as F_pil
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode

from common_utils import TransformsTester

from typing import Dict, List, Sequence, Tuple


NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC


class Tester(TransformsTester):

    def setUp(self):
        self.device = "cpu"

    def _test_fn_on_batch(self, batch_tensors, fn, **fn_kwargs):
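        """Check that fn applied to a whole batch matches fn applied to each sample
        individually, and that the torch.jit.script-ed fn agrees with eager mode."""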
        transformed_batch = fn(batch_tensors, **fn_kwargs)
        for i in range(len(batch_tensors)):
            img_tensor = batch_tensors[i, ...]
            transformed_img = fn(img_tensor, **fn_kwargs)
            self.assertTrue(transformed_img.equal(transformed_batch[i, ...]))

        scripted_fn = torch.jit.script(fn)
        # scriptable function test
        s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
        self.assertTrue(transformed_batch.allclose(s_transformed_batch))

    def test_assert_image_tensor(self):
        shape = (100,)
        tensor = torch.rand(*shape, dtype=torch.float, device=self.device)

        list_of_methods = [(F_t._get_image_size, (tensor, )), (F_t.vflip, (tensor, )),
                           (F_t.hflip, (tensor, )), (F_t.crop, (tensor, 1, 2, 4, 5)),
                           (F_t.adjust_brightness, (tensor, 0.)), (F_t.adjust_contrast, (tensor, 1.)),
                           (F_t.adjust_hue, (tensor, -0.5)), (F_t.adjust_saturation, (tensor, 2.)),
                           (F_t.center_crop, (tensor, [10, 11])), (F_t.five_crop, (tensor, [10, 11])),
                           (F_t.ten_crop, (tensor, [10, 11])), (F_t.pad, (tensor, [2, ], 2, "constant")),
                           (F_t.resize, (tensor, [10, 11])), (F_t.perspective, (tensor, [0.2, ])),
                           (F_t.gaussian_blur, (tensor, (2, 2), (0.7, 0.5))),
                           (F_t.invert, (tensor, )), (F_t.posterize, (tensor, 0)),
                           (F_t.solarize, (tensor, 0.3)), (F_t.adjust_sharpness, (tensor, 0.3)),
                           (F_t.autocontrast, (tensor, )), (F_t.equalize, (tensor, ))]

        for func, args in list_of_methods:
            with self.assertRaises(Exception) as context:
                func(*args)

            self.assertTrue('Tensor is not a torch image.' in str(context.exception))

    def test_vflip(self):
        script_vflip = torch.jit.script(F.vflip)

        img_tensor, pil_img = self._create_data(16, 18, device=self.device)
        vflipped_img = F.vflip(img_tensor)
        vflipped_pil_img = F.vflip(pil_img)
        self.compareTensorToPIL(vflipped_img, vflipped_pil_img)

        # scriptable function test
        vflipped_img_script = script_vflip(img_tensor)
        self.assertTrue(vflipped_img.equal(vflipped_img_script))

        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
        self._test_fn_on_batch(batch_tensors, F.vflip)

    def test_hflip(self):
        script_hflip = torch.jit.script(F.hflip)

        img_tensor, pil_img = self._create_data(16, 18, device=self.device)
        hflipped_img = F.hflip(img_tensor)
        hflipped_pil_img = F.hflip(pil_img)
        self.compareTensorToPIL(hflipped_img, hflipped_pil_img)

        # scriptable function test
        hflipped_img_script = script_hflip(img_tensor)
        self.assertTrue(hflipped_img.equal(hflipped_img_script))

        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
        self._test_fn_on_batch(batch_tensors, F.hflip)

    def test_crop(self):
        script_crop = torch.jit.script(F.crop)

        img_tensor, pil_img = self._create_data(16, 18, device=self.device)

        test_configs = [
            (1, 2, 4, 5),   # crop inside top-left corner
            (2, 12, 3, 4),  # crop inside top-right corner
            (8, 3, 5, 6),   # crop inside bottom-left corner
            (8, 11, 4, 3),  # crop inside bottom-right corner
        ]

        for top, left, height, width in test_configs:
            pil_img_cropped = F.crop(pil_img, top, left, height, width)

            img_tensor_cropped = F.crop(img_tensor, top, left, height, width)
            self.compareTensorToPIL(img_tensor_cropped, pil_img_cropped)

            img_tensor_cropped = script_crop(img_tensor, top, left, height, width)
            self.compareTensorToPIL(img_tensor_cropped, pil_img_cropped)

            batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
            self._test_fn_on_batch(batch_tensors, F.crop, top=top, left=left, height=height, width=width)

    def test_hsv2rgb(self):
        scripted_fn = torch.jit.script(F_t._hsv2rgb)
        shape = (3, 100, 150)
        for _ in range(10):
            hsv_img = torch.rand(*shape, dtype=torch.float, device=self.device)
            rgb_img = F_t._hsv2rgb(hsv_img)
            ft_img = rgb_img.permute(1, 2, 0).flatten(0, 1)

            h, s, v = hsv_img.unbind(0)
            h = h.flatten().cpu().numpy()
            s = s.flatten().cpu().numpy()
            v = v.flatten().cpu().numpy()

            rgb = []
            for h1, s1, v1 in zip(h, s, v):
                rgb.append(colorsys.hsv_to_rgb(h1, s1, v1))
            colorsys_img = torch.tensor(rgb, dtype=torch.float32, device=self.device)
            max_diff = (ft_img - colorsys_img).abs().max()
            self.assertLess(max_diff, 1e-5)

            s_rgb_img = scripted_fn(hsv_img)
            self.assertTrue(rgb_img.allclose(s_rgb_img))

        batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float()
        self._test_fn_on_batch(batch_tensors, F_t._hsv2rgb)

    def test_rgb2hsv(self):
        scripted_fn = torch.jit.script(F_t._rgb2hsv)
        shape = (3, 150, 100)
        for _ in range(10):
            rgb_img = torch.rand(*shape, dtype=torch.float, device=self.device)
            hsv_img = F_t._rgb2hsv(rgb_img)
            ft_hsv_img = hsv_img.permute(1, 2, 0).flatten(0, 1)

            r, g, b = rgb_img.unbind(dim=-3)
            r = r.flatten().cpu().numpy()
            g = g.flatten().cpu().numpy()
            b = b.flatten().cpu().numpy()

            hsv = []
            for r1, g1, b1 in zip(r, g, b):
                hsv.append(colorsys.rgb_to_hsv(r1, g1, b1))

            colorsys_img = torch.tensor(hsv, dtype=torch.float32, device=self.device)

            ft_hsv_img_h, ft_hsv_img_sv = torch.split(ft_hsv_img, [1, 2], dim=1)
            colorsys_img_h, colorsys_img_sv = torch.split(colorsys_img, [1, 2], dim=1)

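            # Hue is a circular quantity (h == 0.0 and h == 1.0 describe the same color),
            # so compare sin(2 * pi * h) rather than h directly, treating values near the
            # wrap-around point as close.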
            max_diff_h = ((colorsys_img_h * 2 * math.pi).sin() - (ft_hsv_img_h * 2 * math.pi).sin()).abs().max()
            max_diff_sv = (colorsys_img_sv - ft_hsv_img_sv).abs().max()
            max_diff = max(max_diff_h, max_diff_sv)
            self.assertLess(max_diff, 1e-5)

            s_hsv_img = scripted_fn(rgb_img)
            self.assertTrue(hsv_img.allclose(s_hsv_img))

        batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float()
        self._test_fn_on_batch(batch_tensors, F_t._rgb2hsv)

    def test_rgb_to_grayscale(self):
        script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)

        img_tensor, pil_img = self._create_data(32, 34, device=self.device)

        for num_output_channels in (3, 1):
            gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
            gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)

            self.approxEqualTensorToPIL(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")

            s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
            self.assertTrue(s_gray_tensor.equal(gray_tensor))

            batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
            self._test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)

    def test_center_crop(self):
        script_center_crop = torch.jit.script(F.center_crop)

        img_tensor, pil_img = self._create_data(32, 34, device=self.device)

        cropped_pil_image = F.center_crop(pil_img, [10, 11])

        cropped_tensor = F.center_crop(img_tensor, [10, 11])
        self.compareTensorToPIL(cropped_tensor, cropped_pil_image)

        cropped_tensor = script_center_crop(img_tensor, [10, 11])
        self.compareTensorToPIL(cropped_tensor, cropped_pil_image)

        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
        self._test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])

    def test_five_crop(self):
        script_five_crop = torch.jit.script(F.five_crop)

        img_tensor, pil_img = self._create_data(32, 34, device=self.device)

        cropped_pil_images = F.five_crop(pil_img, [10, 11])

        cropped_tensors = F.five_crop(img_tensor, [10, 11])
        for i in range(5):
            self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])

        cropped_tensors = script_five_crop(img_tensor, [10, 11])
        for i in range(5):
            self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])

        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
        tuple_transformed_batches = F.five_crop(batch_tensors, [10, 11])
        for i in range(len(batch_tensors)):
            img_tensor = batch_tensors[i, ...]
            tuple_transformed_imgs = F.five_crop(img_tensor, [10, 11])
            self.assertEqual(len(tuple_transformed_imgs), len(tuple_transformed_batches))

            for j in range(len(tuple_transformed_imgs)):
                true_transformed_img = tuple_transformed_imgs[j]
                transformed_img = tuple_transformed_batches[j][i, ...]
                self.assertTrue(true_transformed_img.equal(transformed_img))

        # scriptable function test
        s_tuple_transformed_batches = script_five_crop(batch_tensors, [10, 11])
        for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
            self.assertTrue(transformed_batch.equal(s_transformed_batch))

    def test_ten_crop(self):
        script_ten_crop = torch.jit.script(F.ten_crop)

        img_tensor, pil_img = self._create_data(32, 34, device=self.device)

        cropped_pil_images = F.ten_crop(pil_img, [10, 11])

        cropped_tensors = F.ten_crop(img_tensor, [10, 11])
        for i in range(10):
            self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])

        cropped_tensors = script_ten_crop(img_tensor, [10, 11])
        for i in range(10):
            self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])

        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
        tuple_transformed_batches = F.ten_crop(batch_tensors, [10, 11])
        for i in range(len(batch_tensors)):
            img_tensor = batch_tensors[i, ...]
            tuple_transformed_imgs = F.ten_crop(img_tensor, [10, 11])
            self.assertEqual(len(tuple_transformed_imgs), len(tuple_transformed_batches))

            for j in range(len(tuple_transformed_imgs)):
                true_transformed_img = tuple_transformed_imgs[j]
                transformed_img = tuple_transformed_batches[j][i, ...]
                self.assertTrue(true_transformed_img.equal(transformed_img))

        # scriptable function test
        s_tuple_transformed_batches = script_ten_crop(batch_tensors, [10, 11])
        for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
            self.assertTrue(transformed_batch.equal(s_transformed_batch))

    def test_pad(self):
        script_fn = torch.jit.script(F.pad)
        tensor, pil_img = self._create_data(7, 8, device=self.device)
        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)

        for dt in [None, torch.float32, torch.float64, torch.float16]:

            if dt == torch.float16 and torch.device(self.device).type == "cpu":
                # skip float16 on CPU case
                continue

            if dt is not None:
                # Trivially cast the uint8 data to float so that all dtype cases are tested
                tensor = tensor.to(dt)
                batch_tensors = batch_tensors.to(dt)

            for pad in [2, [3, ], [0, 3], (3, 3), [4, 2, 4, 3]]:
                configs = [
                    {"padding_mode": "constant", "fill": 0},
                    {"padding_mode": "constant", "fill": 10},
                    {"padding_mode": "constant", "fill": 20},
                    {"padding_mode": "edge"},
                    {"padding_mode": "reflect"},
                    {"padding_mode": "symmetric"},
                ]
                for kwargs in configs:
                    pad_tensor = F_t.pad(tensor, pad, **kwargs)
                    pad_pil_img = F_pil.pad(pil_img, pad, **kwargs)

                    pad_tensor_8b = pad_tensor
                    # we need to cast to uint8 to compare with PIL image
                    if pad_tensor_8b.dtype != torch.uint8:
                        pad_tensor_8b = pad_tensor_8b.to(torch.uint8)

                    self.compareTensorToPIL(pad_tensor_8b, pad_pil_img, msg="{}, {}".format(pad, kwargs))

                    if isinstance(pad, int):
                        script_pad = [pad, ]
                    else:
                        script_pad = pad
                    pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
                    self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, kwargs))

                    self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)

    def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max",
                        dts=(None, torch.float32, torch.float64)):
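        """Shared harness: runs fn_t on tensors and fn_pil on PIL images for every
        config and dtype in dts, compares the results within tol, checks the scripted
        fn against eager mode, and runs the batch consistency check via fn."""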
        script_fn = torch.jit.script(fn)
        torch.manual_seed(15)
        tensor, pil_img = self._create_data(26, 34, device=self.device)
        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)

        for dt in dts:

            if dt is not None:
                tensor = F.convert_image_dtype(tensor, dt)
                batch_tensors = F.convert_image_dtype(batch_tensors, dt)

            for config in configs:
                adjusted_tensor = fn_t(tensor, **config)
                adjusted_pil = fn_pil(pil_img, **config)
                scripted_result = script_fn(tensor, **config)
                msg = "{}, {}".format(dt, config)
                self.assertEqual(adjusted_tensor.dtype, scripted_result.dtype, msg=msg)
                self.assertEqual(adjusted_tensor.size()[1:], adjusted_pil.size[::-1], msg=msg)

                rgb_tensor = adjusted_tensor
                if adjusted_tensor.dtype != torch.uint8:
                    rgb_tensor = F.convert_image_dtype(adjusted_tensor, torch.uint8)

                # Check that the max difference does not exceed 2 in the [0, 255] range.
                # Exact matching is not possible due to the incompatibility between
                # convert_image_dtype and PIL results.
                self.approxEqualTensorToPIL(rgb_tensor.float(), adjusted_pil, tol=tol, msg=msg, agg_method=agg_method)

                atol = 1e-6
                if adjusted_tensor.dtype == torch.uint8 and "cuda" in torch.device(self.device).type:
                    atol = 1.0
                self.assertTrue(adjusted_tensor.allclose(scripted_result, atol=atol), msg=msg)

                self._test_fn_on_batch(batch_tensors, fn, **config)

    def test_adjust_brightness(self):
        self._test_adjust_fn(
            F.adjust_brightness,
            F_pil.adjust_brightness,
            F_t.adjust_brightness,
            [{"brightness_factor": f} for f in [0.1, 0.5, 1.0, 1.34, 2.5]]
        )

    def test_adjust_contrast(self):
        self._test_adjust_fn(
            F.adjust_contrast,
            F_pil.adjust_contrast,
            F_t.adjust_contrast,
            [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]
        )

    def test_adjust_saturation(self):
        self._test_adjust_fn(
            F.adjust_saturation,
            F_pil.adjust_saturation,
            F_t.adjust_saturation,
            [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]]
        )

    def test_adjust_hue(self):
        self._test_adjust_fn(
            F.adjust_hue,
            F_pil.adjust_hue,
            F_t.adjust_hue,
            [{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]],
            tol=16.1,
            agg_method="max"
        )

    def test_adjust_gamma(self):
        self._test_adjust_fn(
            F.adjust_gamma,
            F_pil.adjust_gamma,
            F_t.adjust_gamma,
            [{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])]
        )

    def test_resize(self):
        script_fn = torch.jit.script(F.resize)
        tensor, pil_img = self._create_data(26, 36, device=self.device)
        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)

        for dt in [None, torch.float32, torch.float64, torch.float16]:

            if dt == torch.float16 and torch.device(self.device).type == "cpu":
                # skip float16 on CPU case
                continue

            if dt is not None:
                # Trivially cast the uint8 data to float so that all dtype cases are tested
                tensor = tensor.to(dt)
                batch_tensors = batch_tensors.to(dt)

            for size in [32, 26, [32, ], [32, 32], (32, 32), [26, 35]]:
                for max_size in (None, 33, 40, 1000):
                    if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
                        continue  # unsupported, see assertRaises below
                    for interpolation in [BILINEAR, BICUBIC, NEAREST]:
                        resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size)
                        resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size)

                        self.assertEqual(
                            resized_tensor.size()[1:], resized_pil_img.size[::-1],
                            msg="{}, {}".format(size, interpolation)
                        )

                        if interpolation not in [NEAREST, ]:
                            # We cannot check values if mode = NEAREST, as results are different
                            # E.g. resized_tensor  = [[a, a, b, c, d, d, e, ...]]
                            # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
                            resized_tensor_f = resized_tensor
                            # cast to float so the comparison with the PIL image is done in floating point
                            if resized_tensor_f.dtype == torch.uint8:
                                resized_tensor_f = resized_tensor_f.to(torch.float)

                            # Pay attention to high tolerance for MAE
                            self.approxEqualTensorToPIL(
                                resized_tensor_f, resized_pil_img, tol=8.0, msg="{}, {}".format(size, interpolation)
                            )

                        if isinstance(size, int):
                            script_size = [size, ]
                        else:
                            script_size = size

                        resize_result = script_fn(tensor, size=script_size, interpolation=interpolation,
                                                  max_size=max_size)
                        self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))

                        self._test_fn_on_batch(
                            batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size
                        )

        # assert changed type warning
        with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
            res1 = F.resize(tensor, size=32, interpolation=2)
            res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
            self.assertTrue(res1.equal(res2))

        for img in (tensor, pil_img):
            exp_msg = "max_size should only be passed if size specifies the length of the smaller edge"
            with self.assertRaisesRegex(ValueError, exp_msg):
                F.resize(img, size=(32, 34), max_size=35)
            with self.assertRaisesRegex(ValueError, "max_size = 32 must be strictly greater"):
                F.resize(img, size=32, max_size=32)

    def test_resized_crop(self):

        # test values of F.resized_crop in several cases:
        # 1) resize to the same size, crop to the same size => should be identity
        tensor, _ = self._create_data(26, 36, device=self.device)

        for mode in [NEAREST, BILINEAR, BICUBIC]:
            out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode)
            self.assertTrue(tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))

        # 2) resize by half and crop a TL corner
        tensor, _ = self._create_data(26, 36, device=self.device)
        out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=NEAREST)
        expected_out_tensor = tensor[:, :20:2, :30:2]
        self.assertTrue(
            expected_out_tensor.equal(out_tensor),
            msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10])
        )

        batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
        self._test_fn_on_batch(
            batch_tensors, F.resized_crop, top=1, left=2, height=20, width=30, size=[10, 15], interpolation=NEAREST
        )

    def _test_affine_identity_map(self, tensor, scripted_affine):
        # 1) identity map
        out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)

        self.assertTrue(
            tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
        )
        out_tensor = scripted_affine(
            tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
        )
        self.assertTrue(
            tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
        )

    def _test_affine_square_rotations(self, tensor, pil_img, scripted_affine):
        # 2) Test rotation
        test_configs = [
            (90, torch.rot90(tensor, k=1, dims=(-1, -2))),
            (45, None),
            (30, None),
            (-30, None),
            (-45, None),
            (-90, torch.rot90(tensor, k=-1, dims=(-1, -2))),
            (180, torch.rot90(tensor, k=2, dims=(-1, -2))),
        ]
        for a, true_tensor in test_configs:
            out_pil_img = F.affine(
                pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
            )
            out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))).to(self.device)

            for fn in [F.affine, scripted_affine]:
                out_tensor = fn(
                    tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
                )
                if true_tensor is not None:
                    self.assertTrue(
                        true_tensor.equal(out_tensor),
                        msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
                    )

                if out_tensor.dtype != torch.uint8:
                    out_tensor = out_tensor.to(torch.uint8)

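                # (out_tensor != out_pil_tensor) compares element-wise over all 3
                # channels, so divide by 3.0 to turn the element count into a pixel count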
                num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                # Tolerance: less than 6% of different pixels
                self.assertLess(
                    ratio_diff_pixels,
                    0.06,
                    msg="{}\n{} vs \n{}".format(
                        ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
                    )
                )

    def _test_affine_rect_rotations(self, tensor, pil_img, scripted_affine):
        test_configs = [
            90, 45, 15, -30, -60, -120
        ]
        for a in test_configs:

            out_pil_img = F.affine(
                pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
            )
            out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

            for fn in [F.affine, scripted_affine]:
                out_tensor = fn(
                    tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
                ).cpu()

                if out_tensor.dtype != torch.uint8:
                    out_tensor = out_tensor.to(torch.uint8)

                num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                # Tolerance: less than 3% of different pixels
                self.assertLess(
                    ratio_diff_pixels,
                    0.03,
                    msg="{}: {}\n{} vs \n{}".format(
                        a, ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
                    )
                )

    def _test_affine_translations(self, tensor, pil_img, scripted_affine):
        # 3) Test translation
        test_configs = [
            [10, 12], (-12, -13)
        ]
        for t in test_configs:

            out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)

            for fn in [F.affine, scripted_affine]:
                out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)

                if out_tensor.dtype != torch.uint8:
                    out_tensor = out_tensor.to(torch.uint8)

                self.compareTensorToPIL(out_tensor, out_pil_img)

    def _test_affine_all_ops(self, tensor, pil_img, scripted_affine):
        # 4) Test rotation + translation + scale + shear
        test_configs = [
            (45.5, [5, 6], 1.0, [0.0, 0.0], None),
            (33, (5, -4), 1.0, [0.0, 0.0], [0, 0, 0]),
            (45, [-5, 4], 1.2, [0.0, 0.0], (1, 2, 3)),
            (33, (-4, -8), 2.0, [0.0, 0.0], [255, 255, 255]),
            (85, (10, -10), 0.7, [0.0, 0.0], [1, ]),
            (0, [0, 0], 1.0, [35.0, ], (2.0, )),
            (-25, [0, 0], 1.2, [0.0, 15.0], None),
            (-45, [-10, 0], 0.7, [2.0, 5.0], None),
            (-45, [-10, -10], 1.2, [4.0, 5.0], None),
            (-90, [0, 0], 1.0, [0.0, 0.0], None),
        ]
        for r in [NEAREST, ]:
            for a, t, s, sh, f in test_configs:
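                # PIL expects a scalar fill when a single-element sequence is given,
                # so unwrap it; the tensor path accepts the sequence form directly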
                f_pil = int(f[0]) if f is not None and len(f) == 1 else f
                out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, interpolation=r, fill=f_pil)
                out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

                for fn in [F.affine, scripted_affine]:
                    out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, interpolation=r, fill=f).cpu()

                    if out_tensor.dtype != torch.uint8:
                        out_tensor = out_tensor.to(torch.uint8)

                    num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                    ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                    # Tolerance: less than 5% (cpu), 6% (cuda) of different pixels
                    tol = 0.06 if self.device == "cuda" else 0.05
                    self.assertLess(
                        ratio_diff_pixels,
                        tol,
                        msg="{}: {}\n{} vs \n{}".format(
                            (r, a, t, s, sh, f), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
                        )
                    )

    def test_affine(self):
        # Tests on square and rectangular images
        scripted_affine = torch.jit.script(F.affine)

        data = [self._create_data(26, 26, device=self.device), self._create_data(32, 26, device=self.device)]
        for tensor, pil_img in data:

            for dt in [None, torch.float32, torch.float64, torch.float16]:

                if dt == torch.float16 and torch.device(self.device).type == "cpu":
                    # skip float16 on CPU case
                    continue

                if dt is not None:
                    tensor = tensor.to(dtype=dt)

                self._test_affine_identity_map(tensor, scripted_affine)
                if pil_img.size[0] == pil_img.size[1]:
                    self._test_affine_square_rotations(tensor, pil_img, scripted_affine)
                else:
                    self._test_affine_rect_rotations(tensor, pil_img, scripted_affine)
                self._test_affine_translations(tensor, pil_img, scripted_affine)
                self._test_affine_all_ops(tensor, pil_img, scripted_affine)

                batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
                if dt is not None:
                    batch_tensors = batch_tensors.to(dtype=dt)

                self._test_fn_on_batch(
                    batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0]
                )

        tensor, pil_img = data[0]
        # assert deprecation warning and non-BC
        with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
            res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=2)
            res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
            self.assertTrue(res1.equal(res2))

        # assert changed type warning
        with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
            res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=2)
            res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
            self.assertTrue(res1.equal(res2))

        with self.assertWarnsRegex(UserWarning, r"Argument fillcolor is deprecated and will be removed"):
            res1 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fillcolor=10)
            res2 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fill=10)
            self.assertEqual(res1, res2)

    def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
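        """Rotate tensor and pil_img over a grid of angles, expand flags, centers and
        fill values, checking that eager and scripted tensor results match the PIL
        reference in shape and within a small ratio of differing pixels."""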
        img_size = pil_img.size
        dt = tensor.dtype
        for r in [NEAREST, ]:
            for a in range(-180, 180, 17):
                for e in [True, False]:
                    for c in centers:
                        for f in [None, [0, 0, 0], (1, 2, 3), [255, 255, 255], [1, ], (2.0, )]:
                            f_pil = int(f[0]) if f is not None and len(f) == 1 else f
                            out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c, fill=f_pil)
                            out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
                            for fn in [F.rotate, scripted_rotate]:
                                out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c, fill=f).cpu()

                                if out_tensor.dtype != torch.uint8:
                                    out_tensor = out_tensor.to(torch.uint8)

                                self.assertEqual(
                                    out_tensor.shape,
                                    out_pil_tensor.shape,
                                    msg="{}: {} vs {}".format(
                                        (img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
                                    ))

                                num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                                ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                                # Tolerance: less than 3% of different pixels
                                self.assertLess(
                                    ratio_diff_pixels,
                                    0.03,
                                    msg="{}: {}\n{} vs \n{}".format(
                                        (img_size, r, dt, a, e, c, f),
                                        ratio_diff_pixels,
                                        out_tensor[0, :7, :7],
                                        out_pil_tensor[0, :7, :7]
                                    )
                                )

    def test_rotate(self):
        # Tests on square and rectangular images
        scripted_rotate = torch.jit.script(F.rotate)

        data = [self._create_data(26, 26, device=self.device), self._create_data(32, 26, device=self.device)]
        for tensor, pil_img in data:

            img_size = pil_img.size
            centers = [
                None,
                (int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
                [int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
            ]

            for dt in [None, torch.float32, torch.float64, torch.float16]:

                if dt == torch.float16 and torch.device(self.device).type == "cpu":
                    # skip float16 on CPU case
                    continue

                if dt is not None:
                    tensor = tensor.to(dtype=dt)

                self._test_rotate_all_options(tensor, pil_img, scripted_rotate, centers)

                batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
                if dt is not None:
                    batch_tensors = batch_tensors.to(dtype=dt)

                center = (20, 22)
                self._test_fn_on_batch(
                    batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center
                )
        tensor, pil_img = data[0]
        # assert deprecation warning and non-BC
        with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
            res1 = F.rotate(tensor, 45, resample=2)
            res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
            self.assertTrue(res1.equal(res2))

        # assert changed type warning
        with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
            res1 = F.rotate(tensor, 45, interpolation=2)
            res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
            self.assertTrue(res1.equal(res2))

    def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs):
        dt = tensor.dtype
        for f in [None, [0, 0, 0], [1, 2, 3], [255, 255, 255], [1, ], (2.0, )]:
            for r in [NEAREST, ]:
                for spoints, epoints in test_configs:
                    f_pil = int(f[0]) if f is not None and len(f) == 1 else f
                    out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r,
                                                fill=f_pil)
                    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

                    for fn in [F.perspective, scripted_transform]:
                        out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r, fill=f).cpu()

                        if out_tensor.dtype != torch.uint8:
                            out_tensor = out_tensor.to(torch.uint8)

                        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                        ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                        # Tolerance: less than 5% of different pixels
                        self.assertLess(
                            ratio_diff_pixels,
                            0.05,
                            msg="{}: {}\n{} vs \n{}".format(
                                (f, r, dt, spoints, epoints),
                                ratio_diff_pixels,
                                out_tensor[0, :7, :7],
                                out_pil_tensor[0, :7, :7]
                            )
                        )

    def test_perspective(self):

        from torchvision.transforms import RandomPerspective

        data = [self._create_data(26, 34, device=self.device), self._create_data(26, 26, device=self.device)]
        scripted_transform = torch.jit.script(F.perspective)

        for tensor, pil_img in data:

            test_configs = [
                [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
                [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
                [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
            ]
            n = 10
            test_configs += [
                RandomPerspective.get_params(pil_img.size[0], pil_img.size[1], i / n) for i in range(n)
            ]

            for dt in [None, torch.float32, torch.float64, torch.float16]:

                if dt == torch.float16 and torch.device(self.device).type == "cpu":
                    # skip float16 on CPU case
                    continue

                if dt is not None:
                    tensor = tensor.to(dtype=dt)

                self._test_perspective(tensor, pil_img, scripted_transform, test_configs)

                batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
                if dt is not None:
                    batch_tensors = batch_tensors.to(dtype=dt)

                for spoints, epoints in test_configs:
                    self._test_fn_on_batch(
                        batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=NEAREST
                    )

        # assert changed type warning
        spoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
        epoints = [[3, 2], [32, 3], [30, 24], [2, 25]]
        with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
            res1 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=2)
            res2 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=BILINEAR)
            self.assertTrue(res1.equal(res2))

    def test_gaussian_blur(self):
        small_image_tensor = torch.from_numpy(
            np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
        ).permute(2, 0, 1).to(self.device)

        large_image_tensor = torch.from_numpy(
            np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))
        ).to(self.device)

        scripted_transform = torch.jit.script(F.gaussian_blur)

        # true_cv2_results = {
        #     # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
        #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8)
        #     "3_3_0.8": ...
        #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5)
        #     "3_3_0.5": ...
        #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8)
        #     "3_5_0.8": ...
        #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5)
        #     "3_5_0.5": ...
        #     # np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
        #     # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7)
        #     "23_23_1.7": ...
        # }
        p = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'gaussian_blur_opencv_results.pt')
        true_cv2_results = torch.load(p)
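
        # Hedged sketch (not executed here): how a ground-truth file like this could be
        # regenerated with OpenCV, assuming cv2 is available. The key layout mirrors
        # gt_key below ("{h}_{w}_{c}__{kh}_{kw}_{sigma}"); the exact ksize ordering
        # expected by cv2 is an assumption.
        #
        #   import cv2
        #   results = {}
        #   np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
        #   for ksize, sigma in [((3, 3), 0.8), ((3, 3), 0.5), ((3, 5), 0.8), ((3, 5), 0.5)]:
        #       blurred = cv2.GaussianBlur(np_img, ksize=ksize, sigmaX=sigma)
        #       results["10_12_3__{}_{}_{}".format(ksize[0], ksize[1], sigma)] = torch.from_numpy(blurred)
        #   torch.save(results, p)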

        for tensor in [small_image_tensor, large_image_tensor]:

            for dt in [None, torch.float32, torch.float64, torch.float16]:
                if dt == torch.float16 and torch.device(self.device).type == "cpu":
                    # skip float16 on CPU case
                    continue

                if dt is not None:
                    tensor = tensor.to(dtype=dt)

                for ksize in [(3, 3), [3, 5], (23, 23)]:
                    for sigma in [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)]:

                        _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize
                        _sigma = sigma[0] if sigma is not None else None
                        shape = tensor.shape
                        gt_key = "{}_{}_{}__{}_{}_{}".format(
                            shape[-2], shape[-1], shape[-3],
                            _ksize[0], _ksize[1], _sigma
                        )
                        if gt_key not in true_cv2_results:
                            continue

                        true_out = torch.tensor(
                            true_cv2_results[gt_key]
                        ).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)

                        for fn in [F.gaussian_blur, scripted_transform]:
                            out = fn(tensor, kernel_size=ksize, sigma=sigma)
                            self.assertEqual(true_out.shape, out.shape, msg="{}, {}".format(ksize, sigma))
                            self.assertLessEqual(
                                torch.max(true_out.float() - out.float()),
                                1.0,
                                msg="{}, {}".format(ksize, sigma)
                            )

    def test_invert(self):
        self._test_adjust_fn(
            F.invert,
            F_pil.invert,
            F_t.invert,
            [{}],
            tol=1.0,
            agg_method="max"
        )

    def test_posterize(self):
        self._test_adjust_fn(
            F.posterize,
            F_pil.posterize,
            F_t.posterize,
            [{"bits": bits} for bits in range(0, 8)],
            tol=1.0,
            agg_method="max",
            dts=(None,)
        )

    def test_solarize(self):
        self._test_adjust_fn(
            F.solarize,
            F_pil.solarize,
            F_t.solarize,
            [{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]],
            tol=1.0,
            agg_method="max",
            dts=(None,)
        )
        self._test_adjust_fn(
            F.solarize,
            lambda img, threshold: F_pil.solarize(img, 255 * threshold),
            F_t.solarize,
            [{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]],
            tol=1.0,
            agg_method="max",
            dts=(torch.float32, torch.float64)
        )

    def test_adjust_sharpness(self):
        self._test_adjust_fn(
            F.adjust_sharpness,
            F_pil.adjust_sharpness,
            F_t.adjust_sharpness,
            [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]
        )

    def test_autocontrast(self):
        self._test_adjust_fn(
            F.autocontrast,
            F_pil.autocontrast,
            F_t.autocontrast,
            [{}],
            tol=1.0,
            agg_method="max"
        )

    def test_equalize(self):
        torch.set_deterministic(False)
        self._test_adjust_fn(
            F.equalize,
            F_pil.equalize,
            F_t.equalize,
            [{}],
            tol=1.0,
            agg_method="max",
            dts=(None,)
        )


@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):

    def setUp(self):
        self.device = "cuda"

    def test_scale_channel(self):
        """Make sure that _scale_channel gives the same results on CPU and GPU as
        histc or bincount are used depending on the device.
        """
        # TODO: when # https://github.com/pytorch/pytorch/issues/53194 is fixed,
        # only use bincount and remove that test.
        size = (1_000,)
        img_chan = torch.randint(0, 256, size=size).to('cpu')
        scaled_cpu = F_t._scale_channel(img_chan)
        scaled_cuda = F_t._scale_channel(img_chan.to('cuda'))
        self.assertTrue(scaled_cpu.equal(scaled_cuda.to('cpu')))
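
        # Hedged illustration (an assumption about the internals, not torchvision code)
        # of why the two paths can agree: for integer values in [0, 255], torch.bincount
        # on integer data and torch.histc on float data with 256 bins over [0, 255]
        # produce the same 256-bin counts.
        #
        #   chan = torch.randint(0, 256, size=(1_000,))
        #   counts_bincount = torch.bincount(chan, minlength=256)
        #   counts_histc = torch.histc(chan.float(), bins=256, min=0, max=255)
        #   assert counts_bincount.float().equal(counts_histc)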


if __name__ == '__main__':
    unittest.main()