Unverified Commit c4dcfb06 authored by vfdev, committed by GitHub

Added tests on batch of tensors to check transforms (#2584)

* [WIP] Added tests on batch of tensors

* Updated tests on batch of images

* All functional transforms can work with (..., C, H, W) format

* Added transforms tests on batch tensors

* Added batch tests for five/ten crop
- updated docs
parent c36dc43b
@@ -9,6 +9,11 @@ Functional transforms give fine-grained control over the transformations.
This is useful if you have to build a more complex transformation pipeline
(e.g. in the case of segmentation tasks).

+All transformations accept PIL Image, Tensor Image or batch of Tensor Images as input. Tensor Image is a tensor with
+``(C, H, W)`` shape, where ``C`` is the number of channels, and ``H`` and ``W`` are the image height and width. A batch of
+Tensor Images is a tensor of ``(B, C, H, W)`` shape, where ``B`` is the number of images in the batch. Deterministic or
+random transformations applied to a batch of Tensor Images transform all the images of the batch identically.
+
.. autoclass:: Compose

Transforms on PIL Image
...
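The added documentation paragraph is the user-facing contract: a random transform samples its parameters once per call and applies them to every image in a ``(B, C, H, W)`` batch. A minimal sketch of that behavior (the shapes and transform choice are illustrative, not from the diff):

import torch
import torchvision.transforms as T

flip = T.RandomHorizontalFlip(p=0.5)
batch = torch.randint(0, 256, (4, 3, 32, 32), dtype=torch.uint8)  # (B, C, H, W)

# The flip decision is drawn once, so either all 4 images are flipped or none are.
out = flip(batch)
assert out.shape == batch.shape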
@@ -341,6 +341,15 @@ class TransformsTester(unittest.TestCase):
        pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
        return tensor, pil_img

+    def _create_data_batch(self, height=3, width=3, channels=3, num_samples=4, device="cpu"):
+        batch_tensor = torch.randint(
+            0, 255,
+            (num_samples, channels, height, width),
+            dtype=torch.uint8,
+            device=device
+        )
+        return batch_tensor
+
    def compareTensorToPIL(self, tensor, pil_image, msg=None):
        np_pil_image = np.array(pil_image)
        if np_pil_image.ndim == 2:
...
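For reference, the new ``_create_data_batch`` helper is a thin wrapper around ``torch.randint``. A standalone sketch of what a call like the ones added in the tests below returns (the argument values mirror those calls):

import torch

# Equivalent of self._create_data_batch(height=23, width=34, channels=3, num_samples=4)
batch_tensor = torch.randint(0, 255, (4, 3, 23, 34), dtype=torch.uint8)
assert batch_tensor.shape == (4, 3, 23, 34)  # (B, C, H, W)
# Note: torch.randint's upper bound is exclusive, so pixel values fall in [0, 254].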
@@ -19,20 +19,43 @@ class Tester(TransformsTester):
    def _test_functional_op(self, func, fn_kwargs):
        if fn_kwargs is None:
            fn_kwargs = {}
+        f = getattr(F, func)
        tensor, pil_img = self._create_data(height=10, width=10, device=self.device)
-        transformed_tensor = getattr(F, func)(tensor, **fn_kwargs)
-        transformed_pil_img = getattr(F, func)(pil_img, **fn_kwargs)
+        transformed_tensor = f(tensor, **fn_kwargs)
+        transformed_pil_img = f(pil_img, **fn_kwargs)
        self.compareTensorToPIL(transformed_tensor, transformed_pil_img)

+    def _test_transform_vs_scripted(self, transform, s_transform, tensor):
+        torch.manual_seed(12)
+        out1 = transform(tensor)
+        torch.manual_seed(12)
+        out2 = s_transform(tensor)
+        self.assertTrue(out1.equal(out2))
+
+    def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_tensors):
+        torch.manual_seed(12)
+        transformed_batch = transform(batch_tensors)
+
+        for i in range(len(batch_tensors)):
+            img_tensor = batch_tensors[i, ...]
+            torch.manual_seed(12)
+            transformed_img = transform(img_tensor)
+            self.assertTrue(transformed_img.equal(transformed_batch[i, ...]))
+
+        torch.manual_seed(12)
+        s_transformed_batch = s_transform(batch_tensors)
+        self.assertTrue(transformed_batch.equal(s_transformed_batch))
+
    def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs):
        if meth_kwargs is None:
            meth_kwargs = {}
-        tensor, pil_img = self._create_data(26, 34, device=self.device)
        # test for class interface
        f = getattr(T, method)(**meth_kwargs)
        scripted_fn = torch.jit.script(f)
+        tensor, pil_img = self._create_data(26, 34, device=self.device)
        # set seed to reproduce the same transformation for tensor and PIL image
        torch.manual_seed(12)
        transformed_tensor = f(tensor)
@@ -47,6 +70,9 @@ class Tester(TransformsTester):
        transformed_tensor_script = scripted_fn(tensor)
        self.assertTrue(transformed_tensor.equal(transformed_tensor_script))

+        batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
+        self._test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)
+
    def _test_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
        self._test_functional_op(func, fn_kwargs)
        self._test_class_op(method, meth_kwargs)
@@ -167,15 +193,18 @@ class Tester(TransformsTester):
            fn_kwargs = {}
        if meth_kwargs is None:
            meth_kwargs = {}
+        fn = getattr(F, func)
+        scripted_fn = torch.jit.script(fn)
        tensor, pil_img = self._create_data(height=20, width=20, device=self.device)
-        transformed_t_list = getattr(F, func)(tensor, **fn_kwargs)
-        transformed_p_list = getattr(F, func)(pil_img, **fn_kwargs)
+        transformed_t_list = fn(tensor, **fn_kwargs)
+        transformed_p_list = fn(pil_img, **fn_kwargs)
        self.assertEqual(len(transformed_t_list), len(transformed_p_list))
        self.assertEqual(len(transformed_t_list), out_length)
        for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list):
            self.compareTensorToPIL(transformed_tensor, transformed_pil_img)

-        scripted_fn = torch.jit.script(getattr(F, func))
        transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs)
        self.assertEqual(len(transformed_t_list), len(transformed_t_list_script))
        self.assertEqual(len(transformed_t_list_script), out_length)
@@ -184,11 +213,24 @@ class Tester(TransformsTester):
                            msg="{} vs {}".format(transformed_tensor, transformed_tensor_script))

        # test for class interface
-        f = getattr(T, method)(**meth_kwargs)
-        scripted_fn = torch.jit.script(f)
+        fn = getattr(T, method)(**meth_kwargs)
+        scripted_fn = torch.jit.script(fn)
        output = scripted_fn(tensor)
        self.assertEqual(len(output), len(transformed_t_list_script))

+        # test on batch of tensors
+        batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
+        torch.manual_seed(12)
+        transformed_batch_list = fn(batch_tensors)
+
+        for i in range(len(batch_tensors)):
+            img_tensor = batch_tensors[i, ...]
+            torch.manual_seed(12)
+            transformed_img_list = fn(img_tensor)
+            for transformed_img, transformed_batch in zip(transformed_img_list, transformed_batch_list):
+                self.assertTrue(transformed_img.equal(transformed_batch[i, ...]),
+                                msg="{} vs {}".format(transformed_img, transformed_batch[i, ...]))
+
    def test_five_crop(self):
        fn_kwargs = meth_kwargs = {"size": (5,)}
        self._test_op_list_output(
...
@@ -227,6 +269,7 @@ class Tester(TransformsTester):
    def test_resize(self):
        tensor, _ = self._create_data(height=34, width=36, device=self.device)
+        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
        script_fn = torch.jit.script(F.resize)

        for dt in [None, torch.float32, torch.float64]:
@@ -247,13 +290,13 @@ class Tester(TransformsTester):
                self.assertTrue(s_resized_tensor.equal(resized_tensor))

                transform = T.Resize(size=script_size, interpolation=interpolation)
-                resized_tensor = transform(tensor)
-                script_transform = torch.jit.script(transform)
-                s_resized_tensor = script_transform(tensor)
-                self.assertTrue(s_resized_tensor.equal(resized_tensor))
+                s_transform = torch.jit.script(transform)
+                self._test_transform_vs_scripted(transform, s_transform, tensor)
+                self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

    def test_resized_crop(self):
        tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
+        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for scale in [(0.7, 1.2), [0.7, 1.2]]:
            for ratio in [(0.75, 1.333), [0.75, 1.333]]:
@@ -263,15 +306,12 @@ class Tester(TransformsTester):
                        size=size, scale=scale, ratio=ratio, interpolation=interpolation
                    )
                    s_transform = torch.jit.script(transform)
-                    torch.manual_seed(12)
-                    out1 = transform(tensor)
-                    torch.manual_seed(12)
-                    out2 = s_transform(tensor)
-                    self.assertTrue(out1.equal(out2))
+                    self._test_transform_vs_scripted(transform, s_transform, tensor)
+                    self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

    def test_random_affine(self):
        tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
+        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for shear in [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]]:
            for scale in [(0.7, 1.2), [0.7, 1.2]]:
@@ -284,14 +324,12 @@ class Tester(TransformsTester):
                    )
                    s_transform = torch.jit.script(transform)

-                    torch.manual_seed(12)
-                    out1 = transform(tensor)
-                    torch.manual_seed(12)
-                    out2 = s_transform(tensor)
-                    self.assertTrue(out1.equal(out2))
+                    self._test_transform_vs_scripted(transform, s_transform, tensor)
+                    self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

    def test_random_rotate(self):
        tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
+        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for center in [(0, 0), [10, 10], None, (56, 44)]:
            for expand in [True, False]:
@@ -302,14 +340,12 @@ class Tester(TransformsTester):
                    )
                    s_transform = torch.jit.script(transform)

-                    torch.manual_seed(12)
-                    out1 = transform(tensor)
-                    torch.manual_seed(12)
-                    out2 = s_transform(tensor)
-                    self.assertTrue(out1.equal(out2))
+                    self._test_transform_vs_scripted(transform, s_transform, tensor)
+                    self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

    def test_random_perspective(self):
        tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
+        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for distortion_scale in np.linspace(0.1, 1.0, num=20):
            for interpolation in [NEAREST, BILINEAR]:
@@ -319,11 +355,8 @@ class Tester(TransformsTester):
                )
                s_transform = torch.jit.script(transform)

-                torch.manual_seed(12)
-                out1 = transform(tensor)
-                torch.manual_seed(12)
-                out2 = s_transform(tensor)
-                self.assertTrue(out1.equal(out2))
+                self._test_transform_vs_scripted(transform, s_transform, tensor)
+                self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

    def test_to_grayscale(self):
...
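The helpers introduced above all rely on the same seed-reset pattern: re-seeding the global RNG before each call makes a random transform draw identical parameters, so the batch output can be compared element-wise against per-image outputs and against the scripted module. A self-contained sketch of the pattern (the transform and sizes are illustrative):

import torch
import torchvision.transforms as T

t = T.RandomVerticalFlip(p=0.5)
s_t = torch.jit.script(t)
batch = torch.randint(0, 256, (4, 3, 16, 16), dtype=torch.uint8)

torch.manual_seed(12)
out_batch = t(batch)

torch.manual_seed(12)
assert s_t(batch).equal(out_batch)      # eager and scripted agree under the same seed

torch.manual_seed(12)
assert t(batch[0]).equal(out_batch[0])  # batch result matches the per-image result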
@@ -36,7 +36,7 @@ def vflip(img: Tensor) -> Tensor:
    Please, consider instead using methods from `transforms.functional` module.

    Args:
-        img (Tensor): Image Tensor to be flipped in the form [C, H, W].
+        img (Tensor): Image Tensor to be flipped in the form [..., C, H, W].

    Returns:
        Tensor: Vertically flipped image Tensor.
@@ -56,7 +56,7 @@ def hflip(img: Tensor) -> Tensor:
    Please, consider instead using methods from `transforms.functional` module.

    Args:
-        img (Tensor): Image Tensor to be flipped in the form [C, H, W].
+        img (Tensor): Image Tensor to be flipped in the form [..., C, H, W].

    Returns:
        Tensor: Horizontally flipped image Tensor.
@@ -183,7 +183,8 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

-    mean = torch.mean(rgb_to_grayscale(img).to(torch.float))
+    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
+    mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)

    return _blend(img, mean, contrast_factor)
@@ -229,9 +230,9 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
        img = img.to(dtype=torch.float32) / 255.0

    img = _rgb2hsv(img)
-    h, s, v = img.unbind(0)
+    h, s, v = img.unbind(dim=-3)
    h = (h + hue_factor) % 1.0
-    img = torch.stack((h, s, v))
+    img = torch.stack((h, s, v), dim=-3)
    img_hue_adj = _hsv2rgb(img)

    if orig_dtype == torch.uint8:
@@ -466,12 +467,12 @@ def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:

def _rgb2hsv(img):
-    r, g, b = img.unbind(0)
+    r, g, b = img.unbind(dim=-3)

    # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
    # src/libImaging/Convert.c#L330
-    maxc = torch.max(img, dim=0).values
-    minc = torch.min(img, dim=0).values
+    maxc = torch.max(img, dim=-3).values
+    minc = torch.min(img, dim=-3).values

    # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
    # from happening in the results, because
@@ -501,11 +502,11 @@ def _rgb2hsv(img):
    hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
    h = (hr + hg + hb)
    h = torch.fmod((h / 6.0 + 1.0), 1.0)
-    return torch.stack((h, s, maxc))
+    return torch.stack((h, s, maxc), dim=-3)


def _hsv2rgb(img):
-    h, s, v = img.unbind(0)
+    h, s, v = img.unbind(dim=-3)
    i = torch.floor(h * 6.0)
    f = (h * 6.0) - i
    i = i.to(dtype=torch.int32)
@@ -515,14 +516,14 @@ def _hsv2rgb(img):
    t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
    i = i % 6

-    mask = i == torch.arange(6, device=i.device)[:, None, None]
+    mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)

-    a1 = torch.stack((v, q, p, p, t, v))
-    a2 = torch.stack((t, v, v, q, p, p))
-    a3 = torch.stack((p, p, t, v, v, q))
-    a4 = torch.stack((a1, a2, a3))
+    a1 = torch.stack((v, q, p, p, t, v), dim=-3)
+    a2 = torch.stack((t, v, v, q, p, p), dim=-3)
+    a3 = torch.stack((p, p, t, v, v, q), dim=-3)
+    a4 = torch.stack((a1, a2, a3), dim=-4)

-    return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4)
+    return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4)


def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
@@ -793,6 +794,9 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor:
        need_cast = True
        img = img.to(grid)

+    if img.shape[0] > 1:
+        # Apply same grid to a batch of images
+        grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])
+
    img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)

    if need_squeeze:
...
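The common thread in these changes is indexing the channel axis from the end (``dim=-3`` instead of ``dim=0``) and letting ``...`` in ``einsum`` broadcast over any leading batch dimensions, so the same code handles both ``(C, H, W)`` and ``(B, C, H, W)`` inputs. A toy demonstration of the idea (the function name is illustrative, not from the diff):

import torch

def channelwise_max(img: torch.Tensor) -> torch.Tensor:
    # dim=-3 addresses the channel axis whether or not a batch dimension is present
    return torch.max(img, dim=-3).values

assert channelwise_max(torch.rand(3, 8, 8)).shape == (8, 8)
assert channelwise_max(torch.rand(4, 3, 8, 8)).shape == (4, 8, 8)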