Unverified commit c4dcfb06, authored by vfdev, committed by GitHub

Added tests on batch of tensors to check transforms (#2584)

* [WIP] Added tests on batch of tensors

* Updated tests on batch of images

* All functional transforms can work with (..., C, H, W) format

* Added transforms tests on batch tensors

* Added batch tests for five/ten crop
- updated docs
parent c36dc43b
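A minimal sketch (not part of the commit; assumes a torchvision build that includes this change) of what the "(..., C, H, W) format" claim in the commit message means in practice: the same functional call accepts a single image or a batch, and the batch result matches transforming each image independently.

import torch
import torchvision.transforms.functional as F

img = torch.randint(0, 255, (3, 16, 18), dtype=torch.uint8)       # single image, (C, H, W)
batch = torch.randint(0, 255, (4, 3, 16, 18), dtype=torch.uint8)  # batch, (B, C, H, W)

# Both calls go through the same code path.
flipped = F.hflip(img)
flipped_batch = F.hflip(batch)
assert flipped_batch[0].equal(F.hflip(batch[0]))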
@@ -9,6 +9,11 @@ Functional transforms give fine-grained control over the transformations.
This is useful if you have to build a more complex transformation pipeline
(e.g. in the case of segmentation tasks).
All transformations accept PIL Image, Tensor Image or batch of Tensor Images as input. A Tensor Image is a tensor with
``(C, H, W)`` shape, where ``C`` is the number of channels and ``H`` and ``W`` are the image height and width. A batch of
Tensor Images is a tensor of ``(B, C, H, W)`` shape, where ``B`` is the number of images in the batch. Deterministic or
random transformations applied to a batch of Tensor Images transform all images in the batch identically.
.. autoclass:: Compose
Transforms on PIL Image
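To illustrate the documented batch semantics (a sketch, not part of the diff): a random transform draws its parameters once per call, so when it receives a batch, every image gets the same transformation.

import torch
import torchvision.transforms as T

batch = torch.randint(0, 255, (4, 3, 16, 18), dtype=torch.uint8)
transform = T.RandomHorizontalFlip(p=0.5)

out = transform(batch)
# Either every image in the batch was flipped, or none was.
assert out.equal(batch.flip(-1)) or out.equal(batch)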
@@ -341,6 +341,15 @@ class TransformsTester(unittest.TestCase):
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
return tensor, pil_img
def _create_data_batch(self, height=3, width=3, channels=3, num_samples=4, device="cpu"):
batch_tensor = torch.randint(
0, 255,
(num_samples, channels, height, width),
dtype=torch.uint8,
device=device
)
return batch_tensor
def compareTensorToPIL(self, tensor, pil_image, msg=None):
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
@@ -6,7 +6,6 @@ import numpy as np
from PIL.Image import NEAREST, BILINEAR, BICUBIC
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional_tensor as F_t
import torchvision.transforms.functional_pil as F_pil
import torchvision.transforms.functional as F
@@ -19,31 +18,47 @@ class Tester(TransformsTester):
def setUp(self):
self.device = "cpu"
def _test_fn_on_batch(self, batch_tensors, fn, **fn_kwargs):
transformed_batch = fn(batch_tensors, **fn_kwargs)
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
transformed_img = fn(img_tensor, **fn_kwargs)
self.assertTrue(transformed_img.equal(transformed_batch[i, ...]))
scripted_fn = torch.jit.script(fn)
# scriptable function test
s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
self.assertTrue(transformed_batch.allclose(s_transformed_batch))
def test_vflip(self):
script_vflip = torch.jit.script(F_t.vflip)
img_tensor = torch.randn(3, 16, 16, device=self.device)
img_tensor_clone = img_tensor.clone()
vflipped_img = F_t.vflip(img_tensor)
vflipped_img_again = F_t.vflip(vflipped_img)
self.assertEqual(vflipped_img.shape, img_tensor.shape)
self.assertTrue(torch.equal(img_tensor, vflipped_img_again))
self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
script_vflip = torch.jit.script(F.vflip)
img_tensor, pil_img = self._create_data(16, 18, device=self.device)
vflipped_img = F.vflip(img_tensor)
vflipped_pil_img = F.vflip(pil_img)
self.compareTensorToPIL(vflipped_img, vflipped_pil_img)
# scriptable function test
vflipped_img_script = script_vflip(img_tensor)
self.assertTrue(torch.equal(vflipped_img, vflipped_img_script))
self.assertTrue(vflipped_img.equal(vflipped_img_script))
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
self._test_fn_on_batch(batch_tensors, F.vflip)
def test_hflip(self):
script_hflip = torch.jit.script(F_t.hflip)
img_tensor = torch.randn(3, 16, 16, device=self.device)
img_tensor_clone = img_tensor.clone()
hflipped_img = F_t.hflip(img_tensor)
hflipped_img_again = F_t.hflip(hflipped_img)
self.assertEqual(hflipped_img.shape, img_tensor.shape)
self.assertTrue(torch.equal(img_tensor, hflipped_img_again))
self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
script_hflip = torch.jit.script(F.hflip)
img_tensor, pil_img = self._create_data(16, 18, device=self.device)
hflipped_img = F.hflip(img_tensor)
hflipped_pil_img = F.hflip(pil_img)
self.compareTensorToPIL(hflipped_img, hflipped_pil_img)
# scriptable function test
hflipped_img_script = script_hflip(img_tensor)
self.assertTrue(torch.equal(hflipped_img, hflipped_img_script))
self.assertTrue(hflipped_img.equal(hflipped_img_script))
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
self._test_fn_on_batch(batch_tensors, F.hflip)
def test_crop(self):
script_crop = torch.jit.script(F.crop)
@@ -66,6 +81,9 @@ class Tester(TransformsTester):
img_tensor_cropped = script_crop(img_tensor, top, left, height, width)
self.compareTensorToPIL(img_tensor_cropped, pil_img_cropped)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
self._test_fn_on_batch(batch_tensors, F.crop, top=top, left=left, height=height, width=width)
def test_hsv2rgb(self):
scripted_fn = torch.jit.script(F_t._hsv2rgb)
shape = (3, 100, 150)
@@ -89,6 +107,9 @@ class Tester(TransformsTester):
s_rgb_img = scripted_fn(hsv_img)
self.assertTrue(rgb_img.allclose(s_rgb_img))
batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float()
self._test_fn_on_batch(batch_tensors, F_t._hsv2rgb)
def test_rgb2hsv(self):
scripted_fn = torch.jit.script(F_t._rgb2hsv)
shape = (3, 150, 100)
@@ -97,7 +118,7 @@ class Tester(TransformsTester):
hsv_img = F_t._rgb2hsv(rgb_img)
ft_hsv_img = hsv_img.permute(1, 2, 0).flatten(0, 1)
r, g, b, = rgb_img.unbind(0)
r, g, b, = rgb_img.unbind(dim=-3)
r = r.flatten().cpu().numpy()
g = g.flatten().cpu().numpy()
b = b.flatten().cpu().numpy()
@@ -119,6 +140,9 @@ class Tester(TransformsTester):
s_hsv_img = scripted_fn(rgb_img)
self.assertTrue(hsv_img.allclose(s_hsv_img))
batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float()
self._test_fn_on_batch(batch_tensors, F_t._rgb2hsv)
def test_rgb_to_grayscale(self):
script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)
@@ -128,14 +152,14 @@ class Tester(TransformsTester):
gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
if num_output_channels == 1:
self.approxEqualTensorToPIL(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")
s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
self.assertTrue(s_gray_tensor.equal(gray_tensor))
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
self._test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)
def test_center_crop(self):
script_center_crop = torch.jit.script(F.center_crop)
@@ -149,6 +173,9 @@ class Tester(TransformsTester):
cropped_tensor = script_center_crop(img_tensor, [10, 11])
self.compareTensorToPIL(cropped_tensor, cropped_pil_image)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
self._test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])
def test_five_crop(self):
script_five_crop = torch.jit.script(F.five_crop)
@@ -164,6 +191,23 @@ class Tester(TransformsTester):
for i in range(5):
self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
tuple_transformed_batches = F.five_crop(batch_tensors, [10, 11])
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
tuple_transformed_imgs = F.five_crop(img_tensor, [10, 11])
self.assertEqual(len(tuple_transformed_imgs), len(tuple_transformed_batches))
for j in range(len(tuple_transformed_imgs)):
true_transformed_img = tuple_transformed_imgs[j]
transformed_img = tuple_transformed_batches[j][i, ...]
self.assertTrue(true_transformed_img.equal(transformed_img))
# scriptable function test
s_tuple_transformed_batches = script_five_crop(batch_tensors, [10, 11])
for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
self.assertTrue(transformed_batch.equal(s_transformed_batch))
def test_ten_crop(self):
script_ten_crop = torch.jit.script(F.ten_crop)
@@ -179,9 +223,27 @@ class Tester(TransformsTester):
for i in range(10):
self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
tuple_transformed_batches = F.ten_crop(batch_tensors, [10, 11])
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
tuple_transformed_imgs = F.ten_crop(img_tensor, [10, 11])
self.assertEqual(len(tuple_transformed_imgs), len(tuple_transformed_batches))
for j in range(len(tuple_transformed_imgs)):
true_transformed_img = tuple_transformed_imgs[j]
transformed_img = tuple_transformed_batches[j][i, ...]
self.assertTrue(true_transformed_img.equal(transformed_img))
# scriptable function test
s_tuple_transformed_batches = script_ten_crop(batch_tensors, [10, 11])
for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
self.assertTrue(transformed_batch.equal(s_transformed_batch))
def test_pad(self):
script_fn = torch.jit.script(F_t.pad)
script_fn = torch.jit.script(F.pad)
tensor, pil_img = self._create_data(7, 8, device=self.device)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
for dt in [None, torch.float32, torch.float64, torch.float16]:
@@ -192,6 +254,8 @@ class Tester(TransformsTester):
if dt is not None:
# This is a trivial cast of uint8 data to float to test all cases
tensor = tensor.to(dt)
batch_tensors = batch_tensors.to(dt)
for pad in [2, [3, ], [0, 3], (3, 3), [4, 2, 4, 3]]:
configs = [
{"padding_mode": "constant", "fill": 0},
@@ -219,6 +283,8 @@ class Tester(TransformsTester):
pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, kwargs))
self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)
with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
F_t.pad(tensor, (-2, -3), padding_mode="symmetric")
@@ -226,11 +292,13 @@ class Tester(TransformsTester):
script_fn = torch.jit.script(fn)
torch.manual_seed(15)
tensor, pil_img = self._create_data(26, 34, device=self.device)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
for dt in [None, torch.float32, torch.float64]:
if dt is not None:
tensor = F.convert_image_dtype(tensor, dt)
batch_tensors = F.convert_image_dtype(batch_tensors, dt)
for config in configs:
adjusted_tensor = fn_t(tensor, **config)
@@ -254,6 +322,8 @@ class Tester(TransformsTester):
atol = 1.0
self.assertTrue(adjusted_tensor.allclose(scripted_result, atol=atol), msg=msg)
self._test_fn_on_batch(batch_tensors, fn, **config)
def test_adjust_brightness(self):
self._test_adjust_fn(
F.adjust_brightness,
@@ -299,6 +369,7 @@ class Tester(TransformsTester):
def test_resize(self):
script_fn = torch.jit.script(F_t.resize)
tensor, pil_img = self._create_data(26, 36, device=self.device)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
for dt in [None, torch.float32, torch.float64, torch.float16]:
@@ -309,6 +380,8 @@ class Tester(TransformsTester):
if dt is not None:
# This is a trivial cast of uint8 data to float to test all cases
tensor = tensor.to(dt)
batch_tensors = batch_tensors.to(dt)
for size in [32, 26, [32, ], [32, 32], (32, 32), [26, 35]]:
for interpolation in [BILINEAR, BICUBIC, NEAREST]:
resized_tensor = F_t.resize(tensor, size=size, interpolation=interpolation)
@@ -339,6 +412,10 @@ class Tester(TransformsTester):
resize_result = script_fn(tensor, size=script_size, interpolation=interpolation)
self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
self._test_fn_on_batch(
batch_tensors, F.resize, size=script_size, interpolation=interpolation
)
def test_resized_crop(self):
# test values of F.resized_crop in several cases:
# 1) resize to the same size, crop to the same size => should be identity
@@ -356,6 +433,11 @@ class Tester(TransformsTester):
msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10])
)
batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
self._test_fn_on_batch(
batch_tensors, F.resized_crop, top=1, left=2, height=20, width=30, size=[10, 15], interpolation=0
)
def _test_affine_identity_map(self, tensor, scripted_affine):
# 1) identity map
out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
@@ -515,31 +597,19 @@ class Tester(TransformsTester):
else:
self._test_affine_rect_rotations(tensor, pil_img, scripted_affine)
self._test_affine_translations(tensor, pil_img, scripted_affine)
# self._test_affine_all_ops(tensor, pil_img, scripted_affine)
self._test_affine_all_ops(tensor, pil_img, scripted_affine)
def test_rotate(self):
# Tests on square image
scripted_rotate = torch.jit.script(F.rotate)
batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
if dt is not None:
batch_tensors = batch_tensors.to(dtype=dt)
data = [self._create_data(26, 26, device=self.device), self._create_data(32, 26, device=self.device)]
for tensor, pil_img in data:
self._test_fn_on_batch(
batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0]
)
def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
img_size = pil_img.size
centers = [
None,
(int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
[int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
]
for dt in [None, torch.float32, torch.float64, torch.float16]:
if dt == torch.float16 and torch.device(self.device).type == "cpu":
# skip float16 on CPU case
continue
if dt is not None:
tensor = tensor.to(dtype=dt)
dt = tensor.dtype
for r in [0, ]:
for a in range(-180, 180, 17):
for e in [True, False]:
@@ -574,23 +644,18 @@
)
)
def test_perspective(self):
from torchvision.transforms import RandomPerspective
data = [self._create_data(26, 34, device=self.device), self._create_data(26, 26, device=self.device)]
scripted_transform = torch.jit.script(F.perspective)
def test_rotate(self):
# Tests on square image
scripted_rotate = torch.jit.script(F.rotate)
data = [self._create_data(26, 26, device=self.device), self._create_data(32, 26, device=self.device)]
for tensor, pil_img in data:
test_configs = [
[[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
[[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
[[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
]
n = 10
test_configs += [
RandomPerspective.get_params(pil_img.size[0], pil_img.size[1], i / n) for i in range(n)
img_size = pil_img.size
centers = [
None,
(int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
[int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
]
for dt in [None, torch.float32, torch.float64, torch.float16]:
@@ -602,6 +667,19 @@ class Tester(TransformsTester):
if dt is not None:
tensor = tensor.to(dtype=dt)
self._test_rotate_all_options(tensor, pil_img, scripted_rotate, centers)
batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
if dt is not None:
batch_tensors = batch_tensors.to(dtype=dt)
center = (20, 22)
self._test_fn_on_batch(
batch_tensors, F.rotate, angle=32, resample=0, expand=True, center=center
)
def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs):
dt = tensor.dtype
for r in [0, ]:
for spoints, epoints in test_configs:
out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r)
@@ -627,6 +705,45 @@ class Tester(TransformsTester):
)
)
def test_perspective(self):
from torchvision.transforms import RandomPerspective
data = [self._create_data(26, 34, device=self.device), self._create_data(26, 26, device=self.device)]
scripted_transform = torch.jit.script(F.perspective)
for tensor, pil_img in data:
test_configs = [
[[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
[[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
[[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
]
n = 10
test_configs += [
RandomPerspective.get_params(pil_img.size[0], pil_img.size[1], i / n) for i in range(n)
]
for dt in [None, torch.float32, torch.float64, torch.float16]:
if dt == torch.float16 and torch.device(self.device).type == "cpu":
# skip float16 on CPU case
continue
if dt is not None:
tensor = tensor.to(dtype=dt)
self._test_perspective(tensor, pil_img, scripted_transform, test_configs)
batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
if dt is not None:
batch_tensors = batch_tensors.to(dtype=dt)
for spoints, epoints in test_configs:
self._test_fn_on_batch(
batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=0
)
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
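The `_test_fn_on_batch` helper above encodes the core invariant of this PR: the batched output must equal stacking the per-image outputs, and the scripted function must agree with the eager one. A standalone sketch of the same check (the name `check_fn_on_batch` is made up for illustration):

import torch
import torchvision.transforms.functional as F

def check_fn_on_batch(batch, fn, **kwargs):
    out = fn(batch, **kwargs)
    # Each image transformed alone must match its slice of the batch result.
    for i in range(batch.shape[0]):
        assert fn(batch[i], **kwargs).equal(out[i])
    # The TorchScript version must agree with eager mode.
    scripted = torch.jit.script(fn)
    assert out.allclose(scripted(batch, **kwargs))

batch = torch.randint(0, 255, (4, 3, 16, 18), dtype=torch.uint8)
check_fn_on_batch(batch, F.vflip)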
@@ -19,20 +19,43 @@ class Tester(TransformsTester):
def _test_functional_op(self, func, fn_kwargs):
if fn_kwargs is None:
fn_kwargs = {}
f = getattr(F, func)
tensor, pil_img = self._create_data(height=10, width=10, device=self.device)
transformed_tensor = getattr(F, func)(tensor, **fn_kwargs)
transformed_pil_img = getattr(F, func)(pil_img, **fn_kwargs)
transformed_tensor = f(tensor, **fn_kwargs)
transformed_pil_img = f(pil_img, **fn_kwargs)
self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
def _test_transform_vs_scripted(self, transform, s_transform, tensor):
torch.manual_seed(12)
out1 = transform(tensor)
torch.manual_seed(12)
out2 = s_transform(tensor)
self.assertTrue(out1.equal(out2))
def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_tensors):
torch.manual_seed(12)
transformed_batch = transform(batch_tensors)
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
torch.manual_seed(12)
transformed_img = transform(img_tensor)
self.assertTrue(transformed_img.equal(transformed_batch[i, ...]))
torch.manual_seed(12)
s_transformed_batch = s_transform(batch_tensors)
self.assertTrue(transformed_batch.equal(s_transformed_batch))
def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs):
if meth_kwargs is None:
meth_kwargs = {}
tensor, pil_img = self._create_data(26, 34, device=self.device)
# test for class interface
f = getattr(T, method)(**meth_kwargs)
scripted_fn = torch.jit.script(f)
tensor, pil_img = self._create_data(26, 34, device=self.device)
# set seed to reproduce the same transformation for tensor and PIL image
torch.manual_seed(12)
transformed_tensor = f(tensor)
@@ -47,6 +70,9 @@ class Tester(TransformsTester):
transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script))
batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
self._test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)
def _test_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
self._test_functional_op(func, fn_kwargs)
self._test_class_op(method, meth_kwargs)
@@ -167,15 +193,18 @@ class Tester(TransformsTester):
fn_kwargs = {}
if meth_kwargs is None:
meth_kwargs = {}
fn = getattr(F, func)
scripted_fn = torch.jit.script(fn)
tensor, pil_img = self._create_data(height=20, width=20, device=self.device)
transformed_t_list = getattr(F, func)(tensor, **fn_kwargs)
transformed_p_list = getattr(F, func)(pil_img, **fn_kwargs)
transformed_t_list = fn(tensor, **fn_kwargs)
transformed_p_list = fn(pil_img, **fn_kwargs)
self.assertEqual(len(transformed_t_list), len(transformed_p_list))
self.assertEqual(len(transformed_t_list), out_length)
for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list):
self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
scripted_fn = torch.jit.script(getattr(F, func))
transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs)
self.assertEqual(len(transformed_t_list), len(transformed_t_list_script))
self.assertEqual(len(transformed_t_list_script), out_length)
@@ -184,11 +213,24 @@ class Tester(TransformsTester):
msg="{} vs {}".format(transformed_tensor, transformed_tensor_script))
# test for class interface
f = getattr(T, method)(**meth_kwargs)
scripted_fn = torch.jit.script(f)
fn = getattr(T, method)(**meth_kwargs)
scripted_fn = torch.jit.script(fn)
output = scripted_fn(tensor)
self.assertEqual(len(output), len(transformed_t_list_script))
# test on batch of tensors
batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
torch.manual_seed(12)
transformed_batch_list = fn(batch_tensors)
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
torch.manual_seed(12)
transformed_img_list = fn(img_tensor)
for transformed_img, transformed_batch in zip(transformed_img_list, transformed_batch_list):
self.assertTrue(transformed_img.equal(transformed_batch[i, ...]),
msg="{} vs {}".format(transformed_img, transformed_batch[i, ...]))
def test_five_crop(self):
fn_kwargs = meth_kwargs = {"size": (5,)}
self._test_op_list_output(
@@ -227,6 +269,7 @@ class Tester(TransformsTester):
def test_resize(self):
tensor, _ = self._create_data(height=34, width=36, device=self.device)
batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
script_fn = torch.jit.script(F.resize)
for dt in [None, torch.float32, torch.float64]:
@@ -247,13 +290,13 @@ class Tester(TransformsTester):
self.assertTrue(s_resized_tensor.equal(resized_tensor))
transform = T.Resize(size=script_size, interpolation=interpolation)
resized_tensor = transform(tensor)
script_transform = torch.jit.script(transform)
s_resized_tensor = script_transform(tensor)
self.assertTrue(s_resized_tensor.equal(resized_tensor))
s_transform = torch.jit.script(transform)
self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
def test_resized_crop(self):
tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
for scale in [(0.7, 1.2), [0.7, 1.2]]:
for ratio in [(0.75, 1.333), [0.75, 1.333]]:
@@ -263,15 +306,12 @@ class Tester(TransformsTester):
size=size, scale=scale, ratio=ratio, interpolation=interpolation
)
s_transform = torch.jit.script(transform)
torch.manual_seed(12)
out1 = transform(tensor)
torch.manual_seed(12)
out2 = s_transform(tensor)
self.assertTrue(out1.equal(out2))
self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
def test_random_affine(self):
tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
for shear in [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]]:
for scale in [(0.7, 1.2), [0.7, 1.2]]:
@@ -284,14 +324,12 @@ class Tester(TransformsTester):
)
s_transform = torch.jit.script(transform)
torch.manual_seed(12)
out1 = transform(tensor)
torch.manual_seed(12)
out2 = s_transform(tensor)
self.assertTrue(out1.equal(out2))
self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
def test_random_rotate(self):
tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
for center in [(0, 0), [10, 10], None, (56, 44)]:
for expand in [True, False]:
@@ -302,14 +340,12 @@ class Tester(TransformsTester):
)
s_transform = torch.jit.script(transform)
torch.manual_seed(12)
out1 = transform(tensor)
torch.manual_seed(12)
out2 = s_transform(tensor)
self.assertTrue(out1.equal(out2))
self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
def test_random_perspective(self):
tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
for distortion_scale in np.linspace(0.1, 1.0, num=20):
for interpolation in [NEAREST, BILINEAR]:
@@ -319,11 +355,8 @@ class Tester(TransformsTester):
)
s_transform = torch.jit.script(transform)
torch.manual_seed(12)
out1 = transform(tensor)
torch.manual_seed(12)
out2 = s_transform(tensor)
self.assertTrue(out1.equal(out2))
self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
def test_to_grayscale(self):
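The `_test_transform_vs_scripted_on_batch` helper relies on re-seeding the RNG so that the random parameters drawn for the batch call and for each per-image call are identical. A condensed sketch of that trick (not part of the diff):

import torch
import torchvision.transforms as T

transform = T.RandomHorizontalFlip(p=0.5)
batch = torch.randint(0, 255, (4, 3, 16, 18), dtype=torch.uint8)

torch.manual_seed(12)
out_batch = transform(batch)
for i in range(batch.shape[0]):
    torch.manual_seed(12)  # replay the same random draw as the batch call
    assert transform(batch[i]).equal(out_batch[i])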
@@ -36,7 +36,7 @@ def vflip(img: Tensor) -> Tensor:
Please consider using methods from the `transforms.functional` module instead.
Args:
img (Tensor): Image Tensor to be flipped in the form [C, H, W].
img (Tensor): Image Tensor to be flipped in the form [..., C, H, W].
Returns:
Tensor: Vertically flipped image Tensor.
@@ -56,7 +56,7 @@ def hflip(img: Tensor) -> Tensor:
Please consider using methods from the `transforms.functional` module instead.
Args:
img (Tensor): Image Tensor to be flipped in the form [C, H, W].
img (Tensor): Image Tensor to be flipped in the form [..., C, H, W].
Returns:
Tensor: Horizontally flipped image Tensor.
@@ -183,7 +183,8 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
mean = torch.mean(rgb_to_grayscale(img).to(torch.float))
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
return _blend(img, mean, contrast_factor)
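The `adjust_contrast` change computes one grayscale mean per image (`dim=(-3, -2, -1)`, `keepdim=True`) instead of a single scalar over the whole tensor, so the blend broadcasts correctly over a batch. A rough sketch of why the shapes work out, re-implementing the blend inline and using the ITU-R 601 grayscale weights that torchvision's `rgb_to_grayscale` uses (the sketch is illustrative, not the library code):

import torch

batch = torch.rand(4, 3, 8, 8)  # float images in [0, 1]
gray = (0.2989 * batch[:, 0] + 0.587 * batch[:, 1] + 0.114 * batch[:, 2]).unsqueeze(1)

# One mean per image, shape (4, 1, 1, 1), broadcasts against (4, 3, 8, 8).
mean = gray.mean(dim=(-3, -2, -1), keepdim=True)

contrast_factor = 0.5
out = (contrast_factor * batch + (1.0 - contrast_factor) * mean).clamp(0, 1)
assert out.shape == batch.shape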
@@ -229,9 +230,9 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
img = img.to(dtype=torch.float32) / 255.0
img = _rgb2hsv(img)
h, s, v = img.unbind(0)
h, s, v = img.unbind(dim=-3)
h = (h + hue_factor) % 1.0
img = torch.stack((h, s, v))
img = torch.stack((h, s, v), dim=-3)
img_hue_adj = _hsv2rgb(img)
if orig_dtype == torch.uint8:
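The switch from `unbind(0)` to `unbind(dim=-3)`, with the matching `stack(..., dim=-3)`, is what lets these channel-wise helpers ignore any leading batch dimensions. A quick shape check (a sketch, not from the diff):

import torch

for shape in [(3, 8, 8), (4, 3, 8, 8), (2, 4, 3, 8, 8)]:
    img = torch.rand(shape)
    h, s, v = img.unbind(dim=-3)  # each channel keeps all leading dims
    assert torch.stack((h, s, v), dim=-3).shape == img.shape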
@@ -466,12 +467,12 @@ def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
def _rgb2hsv(img):
r, g, b = img.unbind(0)
r, g, b = img.unbind(dim=-3)
# Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
# src/libImaging/Convert.c#L330
maxc = torch.max(img, dim=0).values
minc = torch.min(img, dim=0).values
maxc = torch.max(img, dim=-3).values
minc = torch.min(img, dim=-3).values
# The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
# from happening in the results, because
@@ -501,11 +502,11 @@ def _rgb2hsv(img):
hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
h = (hr + hg + hb)
h = torch.fmod((h / 6.0 + 1.0), 1.0)
return torch.stack((h, s, maxc))
return torch.stack((h, s, maxc), dim=-3)
def _hsv2rgb(img):
h, s, v = img.unbind(0)
h, s, v = img.unbind(dim=-3)
i = torch.floor(h * 6.0)
f = (h * 6.0) - i
i = i.to(dtype=torch.int32)
@@ -515,14 +516,14 @@ def _hsv2rgb(img):
t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
i = i % 6
mask = i == torch.arange(6, device=i.device)[:, None, None]
mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)
a1 = torch.stack((v, q, p, p, t, v))
a2 = torch.stack((t, v, v, q, p, p))
a3 = torch.stack((p, p, t, v, v, q))
a4 = torch.stack((a1, a2, a3))
a1 = torch.stack((v, q, p, p, t, v), dim=-3)
a2 = torch.stack((t, v, v, q, p, p), dim=-3)
a3 = torch.stack((p, p, t, v, v, q), dim=-3)
a4 = torch.stack((a1, a2, a3), dim=-4)
return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4)
return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4)
def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
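The new einsum uses an ellipsis to carry leading batch dimensions through `_hsv2rgb`: `mask` is `(..., 6, H, W)`, the stacked candidates `a4` are `(..., 3, 6, H, W)`, and contracting over the sector index yields `(..., 3, H, W)`. A shape-only sketch with random stand-in values:

import torch

B, H, W = 4, 8, 8
mask = torch.rand(B, 6, H, W)   # stands in for the one-hot sector mask
a4 = torch.rand(B, 3, 6, H, W)  # candidate channel values per sector

out = torch.einsum("...ijk, ...xijk -> ...xjk", mask, a4)
assert out.shape == (B, 3, H, W)

# With no leading dims the same expression reduces to the old behaviour.
assert torch.einsum("...ijk, ...xijk -> ...xjk", mask[0], a4[0]).shape == (3, H, W)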
@@ -793,6 +794,9 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor:
need_cast = True
img = img.to(grid)
if img.shape[0] > 1:
# Apply same grid to a batch of images
grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])
img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
if need_squeeze:
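The `_apply_grid_transform` change expands a single sampling grid across the batch, since `grid_sample` requires the grid's batch size to match the input's. A sketch of the mechanism (standalone, not the library code):

import torch
from torch.nn.functional import grid_sample

imgs = torch.rand(4, 3, 8, 8)
grid = torch.rand(1, 8, 8, 2) * 2 - 1  # one grid, sampling coords in [-1, 1]

# expand shares the same grid across the batch without copying memory.
grid = grid.expand(imgs.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])
out = grid_sample(imgs, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
assert out.shape == (4, 3, 8, 8)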