Unverified Commit 9430be76 authored by Lenz's avatar Lenz Committed by GitHub
Browse files

Added elastic transform in torchvision.transforms (#4938)



* Added elastic augment

* ufmt formatting

* updated comments

* fixed circular dependency issue and bare except error

* Fixed three type checking errors in functional_tensor.py

* ufmt formatted

* changed elastic_deformation to a more common implementation

Implementation uses alpha and sigma to control strength and smoothness of the displacement vectors in elastic_deformation instead of control_point_spacings and sigma.

* ufmt formatting

* Some performance updates

Put random offset vectors to device before gaussian_blur is applied speeds it up 3-fold.

* fixed type error

* fixed again a type error

* Update torchvision/transforms/functional_tensor.py
Co-authored-by: default avatarvfdev <vfdev.5@gmail.com>

* Added some requested changes

- pil image support similar to GaussianBlur
- changed interpolation arg to InterpolationMode
- added a wrapper in torchvision.transforms.functional.py that gets called by the class in transforms.py
- renamed it to ElasticTransform
- handled sigma = 0 case

* added img docstring

* added some tests

* Updated tests and the code

* Added the requested changes to the arguments of F.elastic_transform

Added random_state and displacement as arguments to F.elastic_transform

* fixed the type error

* Fixed tests and docs

* implemented requested changes

Changes:
1) alpha AND sigma OR displacement must be given as arguments to transforms.functional_tensor.elastic_transform instead of alpha AND sigma AND displacement
2) displacements are accepted in transforms.functional.elastic_transform as np.array and torch.Tensor instead of only accepting torch.Tensor

* ufmt formatting

* torchscript error resolved

replaced torch.from_numpy() with torch.Tensor() to make it compatible with torchscript

* revert to torch.from_numpy()

* updated argument checks and errors

- In F.elastic_transform added check to see if both user inputs img and displacement are either of type PIL Image and ndarray or both of type tensor.
- In F_t.elastic_transform added check if alpha and sigma are None if displacement is given or vice versa.

* fixed seed error

changed torch.seed to torch.manual_seed in F_t.elastic_transform

* Reverted displacement type and other cosmetics

* Other minor improvements

* changed gaussian_blur filter size

changed gaussian_blur filter size
from
4 * int(sigma) + 1
to
int(8 * sigma + 1)
to make it consistent with ernestums implementation

* resolved merge error

* Revert "resolved merge error"

This reverts commit 6a4a4e74ff4d078e2c2753d359185f9a81c415d0.

* resolve merge error

* ufmt formatted

* ufmt formatted once again..

* fixed unsupported operand error

* Update API and removed random_state from functional part

* Added default values

* Added ElasticTransform to gallery and updated the docstring

* Updated gallery and added _log_api_usage_once
BTW, matplotlib.pylab is deprecated

* Updated gallery transforms code

* Updates according to review
Co-authored-by: default avatarvfdev <vfdev.5@gmail.com>
parent 8e0f6916
...@@ -149,6 +149,17 @@ affine_transfomer = T.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale ...@@ -149,6 +149,17 @@ affine_transfomer = T.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale
affine_imgs = [affine_transfomer(orig_img) for _ in range(4)] affine_imgs = [affine_transfomer(orig_img) for _ in range(4)]
plot(affine_imgs) plot(affine_imgs)
####################################
# ElasticTransform
# ~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.ElasticTransform` transform
# (see also :func:`~torchvision.transforms.functional.elastic_transform`)
# randomly transforms the morphology of objects in images and produces a
# see-through-water-like effect.
elastic_transformer = T.ElasticTransform(alpha=250.0)
transformed_imgs = [elastic_transformer(orig_img) for _ in range(2)]
plot(transformed_imgs)
#################################### ####################################
# RandomCrop # RandomCrop
# ~~~~~~~~~~ # ~~~~~~~~~~
......
...@@ -325,7 +325,7 @@ print(data) ...@@ -325,7 +325,7 @@ print(data)
# ---------------------------------- # ----------------------------------
# Example of visualized video # Example of visualized video
import matplotlib.pylab as plt import matplotlib.pyplot as plt
plt.figure(figsize=(12, 12)) plt.figure(figsize=(12, 12))
for i in range(16): for i in range(16):
......
...@@ -1363,5 +1363,44 @@ def test_ten_crop(device): ...@@ -1363,5 +1363,44 @@ def test_ten_crop(device):
assert_equal(transformed_batch, s_transformed_batch) assert_equal(transformed_batch, s_transformed_batch)
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC])
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize(
    "fill",
    [
        None,
        [255, 255, 255],
        (2.0,),
    ],
)
def test_elastic_transform_consistency(device, interpolation, dt, fill):
    # elastic_transform has no PIL implementation, so only the eager tensor
    # path is checked against its torchscript-compiled counterpart.
    scripted_fn = torch.jit.script(F.elastic_transform)

    tensor, _ = _create_data(32, 34, device=device)
    if dt is not None:
        tensor = tensor.to(dt)

    params = dict(
        displacement=T.ElasticTransform.get_params([1.5, 1.5], [2.0, 2.0], [32, 34]),
        interpolation=interpolation,
        fill=fill,
    )
    eager_out = F.elastic_transform(tensor, **params)
    scripted_out = scripted_fn(tensor, **params)
    assert_equal(eager_out, scripted_out)

    # Same check on a batched input with a matching displacement field.
    batch = _create_data_batch(16, 18, num_samples=4, device=device)
    params["displacement"] = T.ElasticTransform.get_params([1.5, 1.5], [2.0, 2.0], [16, 18])
    if dt is not None:
        batch = batch.to(dt)
    _test_fn_on_batch(batch, F.elastic_transform, **params)
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
...@@ -2250,5 +2250,42 @@ def test_random_affine(): ...@@ -2250,5 +2250,42 @@ def test_random_affine():
assert t.interpolation == transforms.InterpolationMode.BILINEAR assert t.interpolation == transforms.InterpolationMode.BILINEAR
def test_elastic_transformation():
    # alpha must be a float or a length-2 sequence of floats.
    with pytest.raises(TypeError, match=r"alpha should be float or a sequence of floats"):
        transforms.ElasticTransform(alpha=True, sigma=2.0)
    with pytest.raises(TypeError, match=r"alpha should be a sequence of floats"):
        transforms.ElasticTransform(alpha=[1.0, True], sigma=2.0)
    with pytest.raises(ValueError, match=r"alpha is a sequence its length should be 2"):
        transforms.ElasticTransform(alpha=[1.0, 0.0, 1.0], sigma=2.0)

    # sigma must be a float or a length-2 sequence of floats.
    with pytest.raises(TypeError, match=r"sigma should be float or a sequence of floats"):
        transforms.ElasticTransform(alpha=2.0, sigma=True)
    with pytest.raises(TypeError, match=r"sigma should be a sequence of floats"):
        transforms.ElasticTransform(alpha=2.0, sigma=[1.0, True])
    with pytest.raises(ValueError, match=r"sigma is a sequence its length should be 2"):
        transforms.ElasticTransform(alpha=2.0, sigma=[1.0, 0.0, 1.0])

    # Integer interpolation values are accepted for backward compatibility,
    # but emit a warning and are converted to the enum.
    with pytest.warns(UserWarning, match=r"Argument interpolation should be of type InterpolationMode"):
        t = transforms.ElasticTransform(alpha=2.0, sigma=2.0, interpolation=2)
        assert t.interpolation == transforms.InterpolationMode.BILINEAR
    with pytest.raises(TypeError, match=r"fill should be int or float"):
        transforms.ElasticTransform(alpha=1.0, sigma=1.0, fill={})

    x = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
    img = F.to_pil_image(x)

    # With alpha=0 the displacement field is all zeros, so the image must
    # come back unchanged.
    t = transforms.ElasticTransform(alpha=0.0, sigma=0.0)
    transformed_img = t(img)
    assert transformed_img == img

    # Smoke test on PIL images
    t = transforms.ElasticTransform(alpha=0.5, sigma=0.23)
    transformed_img = t(img)
    assert isinstance(transformed_img, Image.Image)

    # Checking if ElasticTransform can be printed as string
    t.__repr__()
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
...@@ -1484,3 +1484,67 @@ def equalize(img: Tensor) -> Tensor: ...@@ -1484,3 +1484,67 @@ def equalize(img: Tensor) -> Tensor:
return F_pil.equalize(img) return F_pil.equalize(img)
return F_t.equalize(img) return F_t.equalize(img)
def elastic_transform(
    img: Tensor,
    displacement: Tensor,
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Transform a PIL Image or Tensor with elastic transformations.

    The given displacement field is added to an identity grid and the
    resulting grid is used to grid_sample from the image. A random
    displacement field can be generated with
    :meth:`torchvision.transforms.ElasticTransform.get_params`.

    Applications:
        Randomly transforms the morphology of objects in images and produces a
        see-through-water-like effect.

    Args:
        img (PIL Image or Tensor): Image on which elastic_transform is applied.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".
        displacement (Tensor): The displacement field, added to an identity grid.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            Default is ``InterpolationMode.BILINEAR``.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
        fill (sequence of numbers, optional): Pixel fill values for the area outside
            the transformed image. Default is ``None``.

    Returns:
        PIL Image or Tensor: Transformed image, of the same type as the input.

    Raises:
        TypeError: If ``displacement`` is not a Tensor, or ``img`` is neither
            a PIL Image nor a Tensor.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(elastic_transform)
    # Backward compatibility with integer value
    if isinstance(interpolation, int):
        warnings.warn(
            "Argument interpolation should be of type InterpolationMode instead of int. "
            "Please, use InterpolationMode enum."
        )
        interpolation = _interpolation_modes_from_int(interpolation)

    if not isinstance(displacement, torch.Tensor):
        raise TypeError("displacement should be a Tensor")

    # PIL inputs are converted to a tensor for the transform and converted
    # back (preserving the original mode) before returning.
    t_img = img
    if not isinstance(img, torch.Tensor):
        if not F_pil._is_pil_image(img):
            raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}")
        t_img = pil_to_tensor(img)

    output = F_t.elastic_transform(
        t_img,
        displacement,
        interpolation=interpolation.value,
        fill=fill,
    )

    if not isinstance(img, torch.Tensor):
        output = to_pil_image(output, mode=img.mode)
    return output
...@@ -968,3 +968,23 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool ...@@ -968,3 +968,23 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool
img[..., i : i + h, j : j + w] = v img[..., i : i + h, j : j + w] = v
return img return img
def elastic_transform(
    img: Tensor,
    displacement: Tensor,
    interpolation: str = "bilinear",
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Warp a tensor image by adding ``displacement`` to an identity sampling grid.

    Args:
        img (Tensor): Image in [..., H, W] layout; the last two dims define the grid.
        displacement (Tensor): Displacement field, broadcastable against a
            ``1 x H x W x 2`` grid in normalized coordinates.
        interpolation (str): grid_sample interpolation mode, e.g. "bilinear".
        fill (list of floats, optional): Fill values for out-of-image samples.

    Returns:
        Tensor: The warped image.
    """
    if not isinstance(img, torch.Tensor):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    size = list(img.shape[-2:])
    displacement = displacement.to(img.device)

    # Identity grid in grid_sample's normalized coordinates: each axis spans
    # [(-s + 1) / s, (s - 1) / s], i.e. the centers of the first and last pixels.
    # Built directly on the image's device to avoid a host-to-device copy.
    hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s, device=img.device) for s in size]
    grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij")
    identity_grid = torch.stack([grid_x, grid_y], -1).unsqueeze(0)  # 1 x H x W x 2

    grid = identity_grid + displacement
    return _apply_grid_transform(img, grid, interpolation, fill)
...@@ -53,6 +53,7 @@ __all__ = [ ...@@ -53,6 +53,7 @@ __all__ = [
"RandomAdjustSharpness", "RandomAdjustSharpness",
"RandomAutocontrast", "RandomAutocontrast",
"RandomEqualize", "RandomEqualize",
"ElasticTransform",
] ]
...@@ -2049,3 +2050,117 @@ class RandomEqualize(torch.nn.Module): ...@@ -2049,3 +2050,117 @@ class RandomEqualize(torch.nn.Module):
def __repr__(self) -> str: def __repr__(self) -> str:
return f"{self.__class__.__name__}(p={self.p})" return f"{self.__class__.__name__}(p={self.p})"
class ElasticTransform(torch.nn.Module):
    """Transform a tensor image with elastic transformations.

    Given alpha and sigma, it will generate displacement
    vectors for all pixels based on random offsets. Alpha controls the strength
    and sigma controls the smoothness of the displacements.
    The displacements are added to an identity grid and the resulting grid is
    used to grid_sample from the image.

    Applications:
        Randomly transforms the morphology of objects in images and produces a
        see-through-water-like effect.

    Args:
        alpha (float or sequence of floats): Magnitude of displacements. Default is 50.0.
        sigma (float or sequence of floats): Smoothness of displacements. Default is 5.0.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
        fill (number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. The value is used for all bands.
    """

    def __init__(self, alpha=50.0, sigma=5.0, interpolation=InterpolationMode.BILINEAR, fill=0):
        super().__init__()
        _log_api_usage_once(self)
        # alpha and sigma share identical validation rules; both are stored
        # normalized to a 2-element list [value_x, value_y].
        self.alpha = self._check_and_expand(alpha, "alpha")
        self.sigma = self._check_and_expand(sigma, "sigma")

        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument interpolation should be of type InterpolationMode instead of int. "
                "Please, use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)
        self.interpolation = interpolation

        if not isinstance(fill, (int, float)):
            raise TypeError(f"fill should be int or float. Got {type(fill)}")
        self.fill = fill

    @staticmethod
    def _check_and_expand(value, name):
        # Validate a float-or-sequence argument and return a 2-element float list.
        # A length-1 sequence is rejected by the length check below, so only a
        # scalar float needs expansion.
        if not isinstance(value, (float, Sequence)):
            raise TypeError(f"{name} should be float or a sequence of floats. Got {type(value)}")
        if isinstance(value, Sequence) and len(value) != 2:
            raise ValueError(f"If {name} is a sequence its length should be 2. Got {len(value)}")
        if isinstance(value, Sequence):
            for element in value:
                if not isinstance(element, float):
                    raise TypeError(f"{name} should be a sequence of floats. Got {type(element)}")
        if isinstance(value, float):
            value = [float(value), float(value)]
        return value

    @staticmethod
    def get_params(alpha: List[float], sigma: List[float], size: List[int]) -> Tensor:
        """Generate a random displacement field of shape ``1 x H x W x 2``.

        Per-pixel offsets are drawn uniformly from [-1, 1], smoothed with a
        gaussian blur when the corresponding sigma is positive, and scaled by
        alpha relative to the image size.
        """
        dx = torch.rand([1, 1] + size) * 2 - 1
        if sigma[0] > 0.0:
            kx = int(8 * sigma[0] + 1)
            # if kernel size is even we have to make it odd
            if kx % 2 == 0:
                kx += 1
            dx = F.gaussian_blur(dx, [kx, kx], sigma)
        dx = dx * alpha[0] / size[0]

        dy = torch.rand([1, 1] + size) * 2 - 1
        if sigma[1] > 0.0:
            ky = int(8 * sigma[1] + 1)
            # if kernel size is even we have to make it odd
            if ky % 2 == 0:
                ky += 1
            dy = F.gaussian_blur(dy, [ky, ky], sigma)
        dy = dy * alpha[1] / size[1]
        return torch.concat([dx, dy], 1).permute([0, 2, 3, 1])  # 1 x H x W x 2

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        # get_image_size returns [width, height]; get_params expects [H, W].
        size = F.get_image_size(tensor)[::-1]
        displacement = self.get_params(self.alpha, self.sigma, size)
        return F.elastic_transform(tensor, displacement, self.interpolation, self.fill)

    def __repr__(self) -> str:
        # The previous implementation was missing the f-prefix on two pieces,
        # so it printed the literal text "{self.interpolation}" / "{self.fill}".
        return (
            f"{self.__class__.__name__}"
            f"(alpha={self.alpha}"
            f", sigma={self.sigma}"
            f", interpolation={self.interpolation}"
            f", fill={self.fill})"
        )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment