Unverified commit c4dcfb06 authored by vfdev, committed by GitHub

Added tests on batch of tensors to check transforms (#2584)

* [WIP] Added tests on batch of tensors

* Updated tests on batch of images

* All functional transforms can work with (..., C, H, W) format

* Added transforms tests on batch tensors

* Added batch tests for five/ten crop
  - updated docs
parent c36dc43b
@@ -9,6 +9,11 @@ Functional transforms give fine-grained control over the transformations.
 This is useful if you have to build a more complex transformation pipeline
 (e.g. in the case of segmentation tasks).
 
+All transformations accept PIL Image, Tensor Image or batch of Tensor Images as input. Tensor Image is a tensor with
+``(C, H, W)`` shape, where ``C`` is the number of channels, ``H`` and ``W`` are image height and width. Batch of
+Tensor Images is a tensor of ``(B, C, H, W)`` shape, where ``B`` is the number of images in the batch. Deterministic or
+random transformations applied on a batch of Tensor Images transform all the images of the batch identically.
+
 .. autoclass:: Compose
 
 Transforms on PIL Image
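
As an aside (not part of this diff), a minimal sketch of what the batch contract added above means in practice, assuming a torchvision build that includes this change:

```python
import torch
import torchvision.transforms.functional as F

# Batch of 4 RGB images in (B, C, H, W) layout.
batch = torch.randint(0, 256, (4, 3, 16, 18), dtype=torch.uint8)

# A deterministic transform applied to the whole batch...
flipped = F.hflip(batch)

# ...must match the transform applied to each image individually.
for i in range(batch.shape[0]):
    assert F.hflip(batch[i]).equal(flipped[i])
```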
......
@@ -341,6 +341,15 @@ class TransformsTester(unittest.TestCase):
         pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
         return tensor, pil_img
 
+    def _create_data_batch(self, height=3, width=3, channels=3, num_samples=4, device="cpu"):
+        batch_tensor = torch.randint(
+            0, 255,
+            (num_samples, channels, height, width),
+            dtype=torch.uint8,
+            device=device
+        )
+        return batch_tensor
+
     def compareTensorToPIL(self, tensor, pil_image, msg=None):
         np_pil_image = np.array(pil_image)
         if np_pil_image.ndim == 2:
......
@@ -6,7 +6,6 @@ import numpy as np
 from PIL.Image import NEAREST, BILINEAR, BICUBIC
 
 import torch
-import torchvision.transforms as transforms
 import torchvision.transforms.functional_tensor as F_t
 import torchvision.transforms.functional_pil as F_pil
 import torchvision.transforms.functional as F
@@ -19,31 +18,47 @@ class Tester(TransformsTester):
     def setUp(self):
         self.device = "cpu"
 
+    def _test_fn_on_batch(self, batch_tensors, fn, **fn_kwargs):
+        transformed_batch = fn(batch_tensors, **fn_kwargs)
+
+        for i in range(len(batch_tensors)):
+            img_tensor = batch_tensors[i, ...]
+            transformed_img = fn(img_tensor, **fn_kwargs)
+            self.assertTrue(transformed_img.equal(transformed_batch[i, ...]))
+
+        scripted_fn = torch.jit.script(fn)
+        # scriptable function test
+        s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
+        self.assertTrue(transformed_batch.allclose(s_transformed_batch))
+
     def test_vflip(self):
-        script_vflip = torch.jit.script(F_t.vflip)
-        img_tensor = torch.randn(3, 16, 16, device=self.device)
-        img_tensor_clone = img_tensor.clone()
-        vflipped_img = F_t.vflip(img_tensor)
-        vflipped_img_again = F_t.vflip(vflipped_img)
-        self.assertEqual(vflipped_img.shape, img_tensor.shape)
-        self.assertTrue(torch.equal(img_tensor, vflipped_img_again))
-        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
+        script_vflip = torch.jit.script(F.vflip)
+
+        img_tensor, pil_img = self._create_data(16, 18, device=self.device)
+        vflipped_img = F.vflip(img_tensor)
+        vflipped_pil_img = F.vflip(pil_img)
+        self.compareTensorToPIL(vflipped_img, vflipped_pil_img)
 
         # scriptable function test
         vflipped_img_script = script_vflip(img_tensor)
-        self.assertTrue(torch.equal(vflipped_img, vflipped_img_script))
+        self.assertTrue(vflipped_img.equal(vflipped_img_script))
+
+        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
+        self._test_fn_on_batch(batch_tensors, F.vflip)
 
     def test_hflip(self):
-        script_hflip = torch.jit.script(F_t.hflip)
-        img_tensor = torch.randn(3, 16, 16, device=self.device)
-        img_tensor_clone = img_tensor.clone()
-        hflipped_img = F_t.hflip(img_tensor)
-        hflipped_img_again = F_t.hflip(hflipped_img)
-        self.assertEqual(hflipped_img.shape, img_tensor.shape)
-        self.assertTrue(torch.equal(img_tensor, hflipped_img_again))
-        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
+        script_hflip = torch.jit.script(F.hflip)
+
+        img_tensor, pil_img = self._create_data(16, 18, device=self.device)
+        hflipped_img = F.hflip(img_tensor)
+        hflipped_pil_img = F.hflip(pil_img)
+        self.compareTensorToPIL(hflipped_img, hflipped_pil_img)
 
         # scriptable function test
         hflipped_img_script = script_hflip(img_tensor)
-        self.assertTrue(torch.equal(hflipped_img, hflipped_img_script))
+        self.assertTrue(hflipped_img.equal(hflipped_img_script))
+
+        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
+        self._test_fn_on_batch(batch_tensors, F.hflip)
 
     def test_crop(self):
         script_crop = torch.jit.script(F.crop)
@@ -66,6 +81,9 @@ class Tester(TransformsTester):
             img_tensor_cropped = script_crop(img_tensor, top, left, height, width)
             self.compareTensorToPIL(img_tensor_cropped, pil_img_cropped)
 
+        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
+        self._test_fn_on_batch(batch_tensors, F.crop, top=top, left=left, height=height, width=width)
+
     def test_hsv2rgb(self):
         scripted_fn = torch.jit.script(F_t._hsv2rgb)
         shape = (3, 100, 150)
@@ -89,6 +107,9 @@ class Tester(TransformsTester):
             s_rgb_img = scripted_fn(hsv_img)
             self.assertTrue(rgb_img.allclose(s_rgb_img))
 
+        batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float()
+        self._test_fn_on_batch(batch_tensors, F_t._hsv2rgb)
+
     def test_rgb2hsv(self):
         scripted_fn = torch.jit.script(F_t._rgb2hsv)
         shape = (3, 150, 100)
@@ -97,7 +118,7 @@ class Tester(TransformsTester):
         hsv_img = F_t._rgb2hsv(rgb_img)
         ft_hsv_img = hsv_img.permute(1, 2, 0).flatten(0, 1)
 
-        r, g, b, = rgb_img.unbind(0)
+        r, g, b, = rgb_img.unbind(dim=-3)
         r = r.flatten().cpu().numpy()
         g = g.flatten().cpu().numpy()
         b = b.flatten().cpu().numpy()
@@ -119,6 +140,9 @@ class Tester(TransformsTester):
             s_hsv_img = scripted_fn(rgb_img)
             self.assertTrue(hsv_img.allclose(s_hsv_img))
 
+        batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float()
+        self._test_fn_on_batch(batch_tensors, F_t._rgb2hsv)
+
     def test_rgb_to_grayscale(self):
         script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)
@@ -128,14 +152,14 @@ class Tester(TransformsTester):
             gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
             gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
 
-            if num_output_channels == 1:
-                print(gray_tensor.shape)
-
             self.approxEqualTensorToPIL(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")
 
             s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
             self.assertTrue(s_gray_tensor.equal(gray_tensor))
 
+            batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
+            self._test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)
+
     def test_center_crop(self):
         script_center_crop = torch.jit.script(F.center_crop)
@@ -149,6 +173,9 @@ class Tester(TransformsTester):
         cropped_tensor = script_center_crop(img_tensor, [10, 11])
         self.compareTensorToPIL(cropped_tensor, cropped_pil_image)
 
+        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
+        self._test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])
+
     def test_five_crop(self):
         script_five_crop = torch.jit.script(F.five_crop)
@@ -164,6 +191,23 @@ class Tester(TransformsTester):
         for i in range(5):
             self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])
 
+        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
+        tuple_transformed_batches = F.five_crop(batch_tensors, [10, 11])
+        for i in range(len(batch_tensors)):
+            img_tensor = batch_tensors[i, ...]
+            tuple_transformed_imgs = F.five_crop(img_tensor, [10, 11])
+            self.assertEqual(len(tuple_transformed_imgs), len(tuple_transformed_batches))
+
+            for j in range(len(tuple_transformed_imgs)):
+                true_transformed_img = tuple_transformed_imgs[j]
+                transformed_img = tuple_transformed_batches[j][i, ...]
+                self.assertTrue(true_transformed_img.equal(transformed_img))
+
+        # scriptable function test
+        s_tuple_transformed_batches = script_five_crop(batch_tensors, [10, 11])
+        for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
+            self.assertTrue(transformed_batch.equal(s_transformed_batch))
+
     def test_ten_crop(self):
         script_ten_crop = torch.jit.script(F.ten_crop)
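
For context (illustrative sketch, not part of the diff): with batched input, `five_crop` still returns a 5-tuple, and each element keeps the leading batch dimension; the loop above checks exactly this, crop by crop.

```python
import torch
import torchvision.transforms.functional as F

batch = torch.randint(0, 256, (4, 3, 16, 18), dtype=torch.uint8)
crops = F.five_crop(batch, [10, 11])  # (tl, tr, bl, br, center)

assert len(crops) == 5
# Each crop keeps the batch dimension: (B, C, crop_h, crop_w).
assert all(c.shape == (4, 3, 10, 11) for c in crops)
```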
@@ -179,9 +223,27 @@ class Tester(TransformsTester):
         for i in range(10):
             self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])
 
+        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
+        tuple_transformed_batches = F.ten_crop(batch_tensors, [10, 11])
+        for i in range(len(batch_tensors)):
+            img_tensor = batch_tensors[i, ...]
+            tuple_transformed_imgs = F.ten_crop(img_tensor, [10, 11])
+            self.assertEqual(len(tuple_transformed_imgs), len(tuple_transformed_batches))
+
+            for j in range(len(tuple_transformed_imgs)):
+                true_transformed_img = tuple_transformed_imgs[j]
+                transformed_img = tuple_transformed_batches[j][i, ...]
+                self.assertTrue(true_transformed_img.equal(transformed_img))
+
+        # scriptable function test
+        s_tuple_transformed_batches = script_ten_crop(batch_tensors, [10, 11])
+        for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
+            self.assertTrue(transformed_batch.equal(s_transformed_batch))
+
     def test_pad(self):
-        script_fn = torch.jit.script(F_t.pad)
+        script_fn = torch.jit.script(F.pad)
         tensor, pil_img = self._create_data(7, 8, device=self.device)
+        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
 
         for dt in [None, torch.float32, torch.float64, torch.float16]:
@@ -192,6 +254,8 @@ class Tester(TransformsTester):
             if dt is not None:
                 # This is a trivial cast to float of uint8 data to test all cases
                 tensor = tensor.to(dt)
+                batch_tensors = batch_tensors.to(dt)
+
             for pad in [2, [3, ], [0, 3], (3, 3), [4, 2, 4, 3]]:
                 configs = [
                     {"padding_mode": "constant", "fill": 0},
@@ -219,6 +283,8 @@ class Tester(TransformsTester):
                     pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
                     self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, kwargs))
 
+                    self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)
+
         with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
             F_t.pad(tensor, (-2, -3), padding_mode="symmetric")
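
A quick sketch of what the new batch assertion exercises for `F.pad` (illustrative only; padding is given as `[left, top, right, bottom]`):

```python
import torch
import torchvision.transforms.functional as F

batch = torch.randint(0, 256, (4, 3, 16, 18), dtype=torch.uint8)

# [left, top, right, bottom]: W grows by 4 + 4, H grows by 2 + 3.
padded = F.pad(batch, [4, 2, 4, 3], padding_mode="constant", fill=0)
assert padded.shape == (4, 3, 21, 26)
```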
@@ -226,11 +292,13 @@ class Tester(TransformsTester):
         script_fn = torch.jit.script(fn)
         torch.manual_seed(15)
         tensor, pil_img = self._create_data(26, 34, device=self.device)
+        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
 
         for dt in [None, torch.float32, torch.float64]:
 
             if dt is not None:
                 tensor = F.convert_image_dtype(tensor, dt)
+                batch_tensors = F.convert_image_dtype(batch_tensors, dt)
 
             for config in configs:
                 adjusted_tensor = fn_t(tensor, **config)
@@ -254,6 +322,8 @@ class Tester(TransformsTester):
                     atol = 1.0
                 self.assertTrue(adjusted_tensor.allclose(scripted_result, atol=atol), msg=msg)
 
+                self._test_fn_on_batch(batch_tensors, fn, **config)
+
     def test_adjust_brightness(self):
         self._test_adjust_fn(
             F.adjust_brightness,
@@ -299,6 +369,7 @@ class Tester(TransformsTester):
     def test_resize(self):
         script_fn = torch.jit.script(F_t.resize)
         tensor, pil_img = self._create_data(26, 36, device=self.device)
+        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
 
         for dt in [None, torch.float32, torch.float64, torch.float16]:
@@ -309,6 +380,8 @@ class Tester(TransformsTester):
             if dt is not None:
                 # This is a trivial cast to float of uint8 data to test all cases
                 tensor = tensor.to(dt)
+                batch_tensors = batch_tensors.to(dt)
+
             for size in [32, 26, [32, ], [32, 32], (32, 32), [26, 35]]:
                 for interpolation in [BILINEAR, BICUBIC, NEAREST]:
                     resized_tensor = F_t.resize(tensor, size=size, interpolation=interpolation)
@@ -339,6 +412,10 @@ class Tester(TransformsTester):
                     resize_result = script_fn(tensor, size=script_size, interpolation=interpolation)
                     self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
 
+                    self._test_fn_on_batch(
+                        batch_tensors, F.resize, size=script_size, interpolation=interpolation
+                    )
+
     def test_resized_crop(self):
         # test values of F.resized_crop in several cases:
         # 1) resize to the same size, crop to the same size => should be identity
@@ -356,6 +433,11 @@ class Tester(TransformsTester):
                 msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10])
             )
 
+        batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
+        self._test_fn_on_batch(
+            batch_tensors, F.resized_crop, top=1, left=2, height=20, width=30, size=[10, 15], interpolation=0
+        )
+
     def _test_affine_identity_map(self, tensor, scripted_affine):
         # 1) identity map
         out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
@@ -515,31 +597,19 @@ class Tester(TransformsTester):
             else:
                 self._test_affine_rect_rotations(tensor, pil_img, scripted_affine)
             self._test_affine_translations(tensor, pil_img, scripted_affine)
-            # self._test_affine_all_ops(tensor, pil_img, scripted_affine)
+            self._test_affine_all_ops(tensor, pil_img, scripted_affine)
+
+            batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
+            if dt is not None:
+                batch_tensors = batch_tensors.to(dtype=dt)
+            self._test_fn_on_batch(
+                batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0]
+            )
 
-    def test_rotate(self):
-        # Tests on square image
-        scripted_rotate = torch.jit.script(F.rotate)
-
-        data = [self._create_data(26, 26, device=self.device), self._create_data(32, 26, device=self.device)]
-        for tensor, pil_img in data:
-
-            img_size = pil_img.size
-            centers = [
-                None,
-                (int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
-                [int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
-            ]
-
-            for dt in [None, torch.float32, torch.float64, torch.float16]:
-
-                if dt == torch.float16 and torch.device(self.device).type == "cpu":
-                    # skip float16 on CPU case
-                    continue
-
-                if dt is not None:
-                    tensor = tensor.to(dtype=dt)
+    def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
+        img_size = pil_img.size
+        dt = tensor.dtype
         for r in [0, ]:
             for a in range(-180, 180, 17):
                 for e in [True, False]:
@@ -574,23 +644,18 @@ class Tester(TransformsTester):
                         )
                     )
 
-    def test_perspective(self):
-
-        from torchvision.transforms import RandomPerspective
-
-        data = [self._create_data(26, 34, device=self.device), self._create_data(26, 26, device=self.device)]
-        scripted_tranform = torch.jit.script(F.perspective)
-
-        for tensor, pil_img in data:
-            test_configs = [
-                [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
-                [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
-                [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
-            ]
-            n = 10
-            test_configs += [
-                RandomPerspective.get_params(pil_img.size[0], pil_img.size[1], i / n) for i in range(n)
-            ]
+    def test_rotate(self):
+        # Tests on square image
+        scripted_rotate = torch.jit.script(F.rotate)
+
+        data = [self._create_data(26, 26, device=self.device), self._create_data(32, 26, device=self.device)]
+        for tensor, pil_img in data:
+            img_size = pil_img.size
+            centers = [
+                None,
+                (int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
+                [int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
+            ]
 
             for dt in [None, torch.float32, torch.float64, torch.float16]:
@@ -602,6 +667,19 @@ class Tester(TransformsTester):
                 if dt is not None:
                     tensor = tensor.to(dtype=dt)
 
+                self._test_rotate_all_options(tensor, pil_img, scripted_rotate, centers)
+
+                batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
+                if dt is not None:
+                    batch_tensors = batch_tensors.to(dtype=dt)
+
+                center = (20, 22)
+                self._test_fn_on_batch(
+                    batch_tensors, F.rotate, angle=32, resample=0, expand=True, center=center
+                )
+
+    def _test_perspective(self, tensor, pil_img, scripted_tranform, test_configs):
+        dt = tensor.dtype
         for r in [0, ]:
             for spoints, epoints in test_configs:
                 out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r)
@@ -627,6 +705,45 @@ class Tester(TransformsTester):
                     )
                 )
 
+    def test_perspective(self):
+
+        from torchvision.transforms import RandomPerspective
+
+        data = [self._create_data(26, 34, device=self.device), self._create_data(26, 26, device=self.device)]
+        scripted_tranform = torch.jit.script(F.perspective)
+
+        for tensor, pil_img in data:
+            test_configs = [
+                [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
+                [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
+                [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
+            ]
+            n = 10
+            test_configs += [
+                RandomPerspective.get_params(pil_img.size[0], pil_img.size[1], i / n) for i in range(n)
+            ]
+
+            for dt in [None, torch.float32, torch.float64, torch.float16]:
+
+                if dt == torch.float16 and torch.device(self.device).type == "cpu":
+                    # skip float16 on CPU case
+                    continue
+
+                if dt is not None:
+                    tensor = tensor.to(dtype=dt)
+
+                self._test_perspective(tensor, pil_img, scripted_tranform, test_configs)
+
+                batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
+                if dt is not None:
+                    batch_tensors = batch_tensors.to(dtype=dt)
+
+                for spoints, epoints in test_configs:
+                    self._test_fn_on_batch(
+                        batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=0
+                    )
+
 @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
 class CUDATester(Tester):
......
@@ -19,20 +19,43 @@ class Tester(TransformsTester):
     def _test_functional_op(self, func, fn_kwargs):
         if fn_kwargs is None:
             fn_kwargs = {}
+
+        f = getattr(F, func)
         tensor, pil_img = self._create_data(height=10, width=10, device=self.device)
-        transformed_tensor = getattr(F, func)(tensor, **fn_kwargs)
-        transformed_pil_img = getattr(F, func)(pil_img, **fn_kwargs)
+        transformed_tensor = f(tensor, **fn_kwargs)
+        transformed_pil_img = f(pil_img, **fn_kwargs)
         self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
 
+    def _test_transform_vs_scripted(self, transform, s_transform, tensor):
+        torch.manual_seed(12)
+        out1 = transform(tensor)
+
+        torch.manual_seed(12)
+        out2 = s_transform(tensor)
+
+        self.assertTrue(out1.equal(out2))
+
+    def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_tensors):
+        torch.manual_seed(12)
+        transformed_batch = transform(batch_tensors)
+
+        for i in range(len(batch_tensors)):
+            img_tensor = batch_tensors[i, ...]
+            torch.manual_seed(12)
+            transformed_img = transform(img_tensor)
+            self.assertTrue(transformed_img.equal(transformed_batch[i, ...]))
+
+        torch.manual_seed(12)
+        s_transformed_batch = s_transform(batch_tensors)
+        self.assertTrue(transformed_batch.equal(s_transformed_batch))
+
     def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs):
         if meth_kwargs is None:
             meth_kwargs = {}
-        tensor, pil_img = self._create_data(26, 34, device=self.device)
+
         # test for class interface
         f = getattr(T, method)(**meth_kwargs)
         scripted_fn = torch.jit.script(f)
+
+        tensor, pil_img = self._create_data(26, 34, device=self.device)
         # set seed to reproduce the same transformation for tensor and PIL image
         torch.manual_seed(12)
         transformed_tensor = f(tensor)
@@ -47,6 +70,9 @@ class Tester(TransformsTester):
         transformed_tensor_script = scripted_fn(tensor)
         self.assertTrue(transformed_tensor.equal(transformed_tensor_script))
 
+        batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
+        self._test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)
+
     def _test_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
         self._test_functional_op(func, fn_kwargs)
         self._test_class_op(method, meth_kwargs)
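
The seeding trick in `_test_transform_vs_scripted_on_batch` is worth spelling out: a random transform draws its parameters once per call, so re-seeding before the batch call and before each single-image call reproduces the same draw. A minimal sketch (illustrative only):

```python
import torch
import torchvision.transforms as T

transform = T.RandomHorizontalFlip(p=0.5)
batch = torch.randint(0, 256, (4, 3, 16, 18), dtype=torch.uint8)

torch.manual_seed(12)
out_batch = transform(batch)

# Same seed => same random draw, so image 0 must match its batch slice.
torch.manual_seed(12)
out_img = transform(batch[0])
assert out_img.equal(out_batch[0])
```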
@@ -167,15 +193,18 @@ class Tester(TransformsTester):
             fn_kwargs = {}
         if meth_kwargs is None:
             meth_kwargs = {}
+
+        fn = getattr(F, func)
+        scripted_fn = torch.jit.script(fn)
+
         tensor, pil_img = self._create_data(height=20, width=20, device=self.device)
-        transformed_t_list = getattr(F, func)(tensor, **fn_kwargs)
-        transformed_p_list = getattr(F, func)(pil_img, **fn_kwargs)
+        transformed_t_list = fn(tensor, **fn_kwargs)
+        transformed_p_list = fn(pil_img, **fn_kwargs)
         self.assertEqual(len(transformed_t_list), len(transformed_p_list))
         self.assertEqual(len(transformed_t_list), out_length)
         for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list):
             self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
 
-        scripted_fn = torch.jit.script(getattr(F, func))
         transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs)
         self.assertEqual(len(transformed_t_list), len(transformed_t_list_script))
         self.assertEqual(len(transformed_t_list_script), out_length)
@@ -184,11 +213,24 @@ class Tester(TransformsTester):
                             msg="{} vs {}".format(transformed_tensor, transformed_tensor_script))
 
         # test for class interface
-        f = getattr(T, method)(**meth_kwargs)
-        scripted_fn = torch.jit.script(f)
+        fn = getattr(T, method)(**meth_kwargs)
+        scripted_fn = torch.jit.script(fn)
         output = scripted_fn(tensor)
         self.assertEqual(len(output), len(transformed_t_list_script))
 
+        # test on batch of tensors
+        batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
+        torch.manual_seed(12)
+        transformed_batch_list = fn(batch_tensors)
+
+        for i in range(len(batch_tensors)):
+            img_tensor = batch_tensors[i, ...]
+            torch.manual_seed(12)
+            transformed_img_list = fn(img_tensor)
+            for transformed_img, transformed_batch in zip(transformed_img_list, transformed_batch_list):
+                self.assertTrue(transformed_img.equal(transformed_batch[i, ...]),
+                                msg="{} vs {}".format(transformed_img, transformed_batch[i, ...]))
+
     def test_five_crop(self):
         fn_kwargs = meth_kwargs = {"size": (5,)}
         self._test_op_list_output(
@@ -227,6 +269,7 @@ class Tester(TransformsTester):
     def test_resize(self):
         tensor, _ = self._create_data(height=34, width=36, device=self.device)
+        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
         script_fn = torch.jit.script(F.resize)
 
         for dt in [None, torch.float32, torch.float64]:
@@ -247,13 +290,13 @@ class Tester(TransformsTester):
                     self.assertTrue(s_resized_tensor.equal(resized_tensor))
 
                     transform = T.Resize(size=script_size, interpolation=interpolation)
-                    resized_tensor = transform(tensor)
-                    script_transform = torch.jit.script(transform)
-                    s_resized_tensor = script_transform(tensor)
-                    self.assertTrue(s_resized_tensor.equal(resized_tensor))
+                    s_transform = torch.jit.script(transform)
+                    self._test_transform_vs_scripted(transform, s_transform, tensor)
+                    self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
     def test_resized_crop(self):
         tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
+        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
 
         for scale in [(0.7, 1.2), [0.7, 1.2]]:
             for ratio in [(0.75, 1.333), [0.75, 1.333]]:
@@ -263,15 +306,12 @@ class Tester(TransformsTester):
                         size=size, scale=scale, ratio=ratio, interpolation=interpolation
                     )
                     s_transform = torch.jit.script(transform)
-
-                    torch.manual_seed(12)
-                    out1 = transform(tensor)
-                    torch.manual_seed(12)
-                    out2 = s_transform(tensor)
-                    self.assertTrue(out1.equal(out2))
+                    self._test_transform_vs_scripted(transform, s_transform, tensor)
+                    self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
     def test_random_affine(self):
         tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
+        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
 
         for shear in [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]]:
             for scale in [(0.7, 1.2), [0.7, 1.2]]:
@@ -284,14 +324,12 @@ class Tester(TransformsTester):
                     )
                     s_transform = torch.jit.script(transform)
 
-                    torch.manual_seed(12)
-                    out1 = transform(tensor)
-                    torch.manual_seed(12)
-                    out2 = s_transform(tensor)
-                    self.assertTrue(out1.equal(out2))
+                    self._test_transform_vs_scripted(transform, s_transform, tensor)
+                    self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
     def test_random_rotate(self):
         tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
+        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
 
         for center in [(0, 0), [10, 10], None, (56, 44)]:
             for expand in [True, False]:
@@ -302,14 +340,12 @@ class Tester(TransformsTester):
                 )
                 s_transform = torch.jit.script(transform)
 
-                torch.manual_seed(12)
-                out1 = transform(tensor)
-                torch.manual_seed(12)
-                out2 = s_transform(tensor)
-                self.assertTrue(out1.equal(out2))
+                self._test_transform_vs_scripted(transform, s_transform, tensor)
+                self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
     def test_random_perspective(self):
         tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
+        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
 
         for distortion_scale in np.linspace(0.1, 1.0, num=20):
             for interpolation in [NEAREST, BILINEAR]:
@@ -319,11 +355,8 @@ class Tester(TransformsTester):
                 )
                 s_transform = torch.jit.script(transform)
 
-                torch.manual_seed(12)
-                out1 = transform(tensor)
-                torch.manual_seed(12)
-                out2 = s_transform(tensor)
-                self.assertTrue(out1.equal(out2))
+                self._test_transform_vs_scripted(transform, s_transform, tensor)
+                self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
     def test_to_grayscale(self):
......
@@ -36,7 +36,7 @@ def vflip(img: Tensor) -> Tensor:
     Please, consider instead using methods from `transforms.functional` module.
 
     Args:
-        img (Tensor): Image Tensor to be flipped in the form [C, H, W].
+        img (Tensor): Image Tensor to be flipped in the form [..., C, H, W].
 
     Returns:
         Tensor: Vertically flipped image Tensor.
@@ -56,7 +56,7 @@ def hflip(img: Tensor) -> Tensor:
     Please, consider instead using methods from `transforms.functional` module.
 
    Args:
-        img (Tensor): Image Tensor to be flipped in the form [C, H, W].
+        img (Tensor): Image Tensor to be flipped in the form [..., C, H, W].
 
     Returns:
         Tensor: Horizontally flipped image Tensor.
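
The `[..., C, H, W]` wording means the flips only touch the last two dimensions, so they work unchanged on batches and are involutions. A sketch (illustrative, assuming a build with this change):

```python
import torch
import torchvision.transforms.functional as F

batch = torch.rand(4, 3, 32, 32)  # (B, C, H, W)

assert F.vflip(batch).shape == batch.shape
# Flipping twice restores the input exactly.
assert F.vflip(F.vflip(batch)).equal(batch)
assert F.hflip(F.hflip(batch)).equal(batch)
```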
@@ -183,7 +183,8 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
     if not _is_tensor_a_torch_image(img):
         raise TypeError('tensor is not a torch image.')
 
-    mean = torch.mean(rgb_to_grayscale(img).to(torch.float))
+    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
+    mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
 
     return _blend(img, mean, contrast_factor)
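
The point of `dim=(-3, -2, -1), keepdim=True` is that each image in a batch is blended against its own grayscale mean instead of one scalar shared by the whole batch, and the result stays broadcastable over `(B, C, H, W)`. A shape-only sketch (illustrative; the ITU-R 601 weights below merely stand in for `rgb_to_grayscale`):

```python
import torch

batch = torch.rand(4, 3, 8, 8)  # (B, C, H, W)

# Grayscale with ITU-R 601 weights, keeping a channel axis: (B, 1, H, W).
gray = (0.2989 * batch[:, 0] + 0.587 * batch[:, 1] + 0.114 * batch[:, 2]).unsqueeze(1)

# One mean per image, broadcastable back over (B, C, H, W).
mean = gray.mean(dim=(-3, -2, -1), keepdim=True)
assert mean.shape == (4, 1, 1, 1)
```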
@@ -229,9 +230,9 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
         img = img.to(dtype=torch.float32) / 255.0
 
     img = _rgb2hsv(img)
-    h, s, v = img.unbind(0)
+    h, s, v = img.unbind(dim=-3)
     h = (h + hue_factor) % 1.0
-    img = torch.stack((h, s, v))
+    img = torch.stack((h, s, v), dim=-3)
     img_hue_adj = _hsv2rgb(img)
 
     if orig_dtype == torch.uint8:
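
`dim=-3` addresses the channel axis counted from the end, which is what lets the same code path serve both `(C, H, W)` images and `(B, C, H, W)` batches:

```python
import torch

img = torch.rand(3, 4, 5)       # (C, H, W)
batch = torch.rand(2, 3, 4, 5)  # (B, C, H, W)

r, g, b = img.unbind(dim=-3)        # each (4, 5)
br, bg, bb = batch.unbind(dim=-3)   # each (2, 4, 5)
assert r.shape == (4, 5) and br.shape == (2, 4, 5)
```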
@@ -466,12 +467,12 @@ def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
 def _rgb2hsv(img):
-    r, g, b = img.unbind(0)
+    r, g, b = img.unbind(dim=-3)
 
     # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
     # src/libImaging/Convert.c#L330
-    maxc = torch.max(img, dim=0).values
-    minc = torch.min(img, dim=0).values
+    maxc = torch.max(img, dim=-3).values
+    minc = torch.min(img, dim=-3).values
 
     # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
     # from happening in the results, because
@@ -501,11 +502,11 @@ def _rgb2hsv(img):
     hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
     h = (hr + hg + hb)
     h = torch.fmod((h / 6.0 + 1.0), 1.0)
-    return torch.stack((h, s, maxc))
+    return torch.stack((h, s, maxc), dim=-3)
 
 
 def _hsv2rgb(img):
-    h, s, v = img.unbind(0)
+    h, s, v = img.unbind(dim=-3)
     i = torch.floor(h * 6.0)
     f = (h * 6.0) - i
     i = i.to(dtype=torch.int32)
@@ -515,14 +516,14 @@ def _hsv2rgb(img):
     t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
     i = i % 6
 
-    mask = i == torch.arange(6, device=i.device)[:, None, None]
+    mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)
 
-    a1 = torch.stack((v, q, p, p, t, v))
-    a2 = torch.stack((t, v, v, q, p, p))
-    a3 = torch.stack((p, p, t, v, v, q))
-    a4 = torch.stack((a1, a2, a3))
+    a1 = torch.stack((v, q, p, p, t, v), dim=-3)
+    a2 = torch.stack((t, v, v, q, p, p), dim=-3)
+    a3 = torch.stack((p, p, t, v, v, q), dim=-3)
+    a4 = torch.stack((a1, a2, a3), dim=-4)
 
-    return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4)
+    return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4)
 
 
 def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
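
The ellipsis in the new einsum carries any leading batch dimensions through the hue-sector selection. A shape-only sketch of the contraction:

```python
import torch

# mask: one-hot over the 6 hue sectors, shape (..., 6, H, W)
# a4:   candidate R/G/B values per sector, shape (..., 3, 6, H, W)
mask = torch.rand(2, 6, 4, 5)
a4 = torch.rand(2, 3, 6, 4, 5)

out = torch.einsum("...ijk, ...xijk -> ...xjk", mask, a4)
assert out.shape == (2, 3, 4, 5)  # (..., C, H, W)
```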
@@ -793,6 +794,9 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor:
         need_cast = True
         img = img.to(grid)
 
+    if img.shape[0] > 1:
+        # Apply same grid to a batch of images
+        grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])
+
     img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
 
     if need_squeeze:
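
`grid.expand` broadcasts a single sampling grid across the batch without copying, so `grid_sample` warps every image identically. The same idea in isolation (illustrative only):

```python
import torch
from torch.nn.functional import affine_grid, grid_sample

imgs = torch.rand(4, 3, 8, 8)          # batch of 4 images
theta = torch.eye(2, 3).unsqueeze(0)   # identity affine, batch size 1

grid = affine_grid(theta, size=(1, 3, 8, 8), align_corners=False)
# Expand (a view, not a copy) so the one grid matches the batch size.
grid = grid.expand(imgs.shape[0], *grid.shape[1:])

out = grid_sample(imgs, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
assert out.shape == imgs.shape
```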
......