Unverified Commit 4106dbb8 authored by Tejan Karmali, committed by GitHub

Added GaussianBlur transform (#2658)



* Added GaussianBlur transform

* fixed linting

* supports fixed radius for kernel

* [WIP] New API for gaussian_blur

* Gaussian blur with kernelsize and sigma API

* Fixed implementation and updated tests

* Added large input case and refactored gt into a file

* Updated docs

* fix kernel dimensions order while creating kernel

* added tests for exception handling of gaussian blur

* fix linting, bug in tests

* Fixed failing tests, refactored code and other minor fixes
Co-authored-by: vfdev-5 <vfdev.5@gmail.com>
parent 87c78641
@@ -81,6 +81,8 @@ Transforms on PIL Image

.. autoclass:: TenCrop
.. autoclass:: GaussianBlur

Transforms on torch.\*Tensor
----------------------------
import os
import unittest
import colorsys
import math
@@ -675,14 +676,14 @@ class Tester(TransformsTester):
                batch_tensors, F.rotate, angle=32, resample=0, expand=True, center=center
            )

    def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs):
        dt = tensor.dtype
        for r in [0, ]:
            for spoints, epoints in test_configs:
                out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r)
                out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

                for fn in [F.perspective, scripted_transform]:
                    out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r).cpu()

                    if out_tensor.dtype != torch.uint8:
@@ -707,7 +708,7 @@ class Tester(TransformsTester):
        from torchvision.transforms import RandomPerspective

        data = [self._create_data(26, 34, device=self.device), self._create_data(26, 26, device=self.device)]
        scripted_transform = torch.jit.script(F.perspective)

        for tensor, pil_img in data:
@@ -730,7 +731,7 @@ class Tester(TransformsTester):
                if dt is not None:
                    tensor = tensor.to(dtype=dt)

                self._test_perspective(tensor, pil_img, scripted_transform, test_configs)

                batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
                if dt is not None:
@@ -741,6 +742,70 @@ class Tester(TransformsTester):
                batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=0
            )
    def test_gaussian_blur(self):
        small_image_tensor = torch.from_numpy(
            np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
        ).permute(2, 0, 1).to(self.device)

        large_image_tensor = torch.from_numpy(
            np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))
        ).to(self.device)

        scripted_transform = torch.jit.script(F.gaussian_blur)

        # true_cv2_results = {
        #     # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
        #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8)
        #     "3_3_0.8": ...
        #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5)
        #     "3_3_0.5": ...
        #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8)
        #     "3_5_0.8": ...
        #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5)
        #     "3_5_0.5": ...
        #     # np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
        #     # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7)
        #     "23_23_1.7": ...
        # }
        p = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'gaussian_blur_opencv_results.pt')
        true_cv2_results = torch.load(p)

        for tensor in [small_image_tensor, large_image_tensor]:
            for dt in [None, torch.float32, torch.float64, torch.float16]:
                if dt == torch.float16 and torch.device(self.device).type == "cpu":
                    # skip float16 on CPU case
                    continue

                if dt is not None:
                    tensor = tensor.to(dtype=dt)

                for ksize in [(3, 3), [3, 5], (23, 23)]:
                    for sigma in [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)]:
                        _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize
                        _sigma = sigma[0] if sigma is not None else None
                        shape = tensor.shape
                        gt_key = "{}_{}_{}__{}_{}_{}".format(
                            shape[-2], shape[-1], shape[-3],
                            _ksize[0], _ksize[1], _sigma
                        )
                        if gt_key not in true_cv2_results:
                            continue

                        true_out = torch.tensor(
                            true_cv2_results[gt_key]
                        ).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)

                        for fn in [F.gaussian_blur, scripted_transform]:
                            out = fn(tensor, kernel_size=ksize, sigma=sigma)
                            self.assertEqual(true_out.shape, out.shape, msg="{}, {}".format(ksize, sigma))
                            self.assertLessEqual(
                                torch.max(true_out.float() - out.float()),
                                1.0,
                                msg="{}, {}".format(ksize, sigma)
                            )
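A hedged sketch, not part of the commit, of one way the ground-truth file loaded above could be regenerated, assuming OpenCV is installed; the dictionary keys mirror the gt_key format built in test_gaussian_blur, and the actual assets/gaussian_blur_opencv_results.pt may have been produced differently.

# Regenerate OpenCV reference results with keys "{h}_{w}_{c}__{kx}_{ky}_{sigma}".
import cv2
import numpy as np
import torch

images = [
    np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3)),  # matches small_image_tensor
    np.arange(26 * 28, dtype="uint8").reshape((26, 28)),         # matches large_image_tensor
]
results = {}
for np_img in images:
    h, w = np_img.shape[:2]
    c = np_img.shape[2] if np_img.ndim == 3 else 1
    for kx, ky in [(3, 3), (3, 5), (23, 23)]:
        for s in [0.5, 0.8, 1.7]:
            out = cv2.GaussianBlur(np_img, ksize=(kx, ky), sigmaX=s)
            results["{}_{}_{}__{}_{}_{}".format(h, w, c, kx, ky, s)] = out.reshape(h, w, c)

torch.save(results, "gaussian_blur_opencv_results.pt")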
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
@@ -1654,6 +1654,48 @@ class Tester(unittest.TestCase):
        # Checking if RandomGrayscale can be printed as string
        trans3.__repr__()
    def test_gaussian_blur_asserts(self):
        np_img = np.ones((100, 100, 3), dtype=np.uint8) * 255
        img = F.to_pil_image(np_img, "RGB")

        with self.assertRaisesRegex(ValueError, r"If kernel_size is a sequence its length should be 2"):
            F.gaussian_blur(img, [3])
        with self.assertRaisesRegex(ValueError, r"If kernel_size is a sequence its length should be 2"):
            F.gaussian_blur(img, [3, 3, 3])
        with self.assertRaisesRegex(ValueError, r"Kernel size should be a tuple/list of two integers"):
            transforms.GaussianBlur([3, 3, 3])

        with self.assertRaisesRegex(ValueError, r"kernel_size should have odd and positive integers"):
            F.gaussian_blur(img, [4, 4])
        with self.assertRaisesRegex(ValueError, r"Kernel size value should be an odd and positive number"):
            transforms.GaussianBlur([4, 4])

        with self.assertRaisesRegex(ValueError, r"kernel_size should have odd and positive integers"):
            F.gaussian_blur(img, [-3, -3])
        with self.assertRaisesRegex(ValueError, r"Kernel size value should be an odd and positive number"):
            transforms.GaussianBlur([-3, -3])

        with self.assertRaisesRegex(ValueError, r"If sigma is a sequence, its length should be 2"):
            F.gaussian_blur(img, 3, [1, 1, 1])
        with self.assertRaisesRegex(ValueError, r"sigma should be a single number or a list/tuple with length 2"):
            transforms.GaussianBlur(3, [1, 1, 1])

        with self.assertRaisesRegex(ValueError, r"sigma should have positive values"):
            F.gaussian_blur(img, 3, -1.0)
        with self.assertRaisesRegex(ValueError, r"If sigma is a single number, it must be positive"):
            transforms.GaussianBlur(3, -1.0)

        with self.assertRaisesRegex(TypeError, r"kernel_size should be int or a sequence of integers"):
            F.gaussian_blur(img, "kernel_size_string")
        with self.assertRaisesRegex(ValueError, r"Kernel size should be a tuple/list of two integers"):
            transforms.GaussianBlur("kernel_size_string")

        with self.assertRaisesRegex(TypeError, r"sigma should be either float or sequence of floats"):
            F.gaussian_blur(img, 3, "sigma_string")
        with self.assertRaisesRegex(ValueError, r"sigma should be a single number or a list/tuple with length 2"):
            transforms.GaussianBlur(3, "sigma_string")
if __name__ == '__main__':
    unittest.main()

@@ -466,6 +466,38 @@ class Tester(TransformsTester):
            with self.assertRaisesRegex(RuntimeError, r"Could not get name of python class object"):
                torch.jit.script(t)
    def test_gaussian_blur(self):
        tol = 1.0 + 1e-10
        self._test_class_op(
            "GaussianBlur", meth_kwargs={"kernel_size": 3, "sigma": 0.75},
            test_exact_match=False, agg_method="max", tol=tol
        )
        self._test_class_op(
            "GaussianBlur", meth_kwargs={"kernel_size": 23, "sigma": [0.1, 2.0]},
            test_exact_match=False, agg_method="max", tol=tol
        )
        self._test_class_op(
            "GaussianBlur", meth_kwargs={"kernel_size": 23, "sigma": (0.1, 2.0)},
            test_exact_match=False, agg_method="max", tol=tol
        )
        self._test_class_op(
            "GaussianBlur", meth_kwargs={"kernel_size": [3, 3], "sigma": (1.0, 1.0)},
            test_exact_match=False, agg_method="max", tol=tol
        )
        self._test_class_op(
            "GaussianBlur", meth_kwargs={"kernel_size": (3, 3), "sigma": (0.1, 2.0)},
            test_exact_match=False, agg_method="max", tol=tol
        )
        self._test_class_op(
            "GaussianBlur", meth_kwargs={"kernel_size": [23], "sigma": 0.75},
            test_exact_match=False, agg_method="max", tol=tol
        )
    def test_random_erasing(self):
        img = torch.rand(3, 60, 60)
@@ -115,7 +115,7 @@ def pil_to_tensor(pic):
    Returns:
        Tensor: Converted image.
    """
    if not F_pil._is_pil_image(pic):
        raise TypeError('pic should be PIL Image. Got {}'.format(type(pic)))

    if accimage is not None and isinstance(pic, accimage.Image):
@@ -297,7 +297,7 @@ def resize(img: Tensor, size: List[int], interpolation: int = Image.BILINEAR) ->
            the smaller edge of the image will be matched to this number maintaining
            the aspect ratio. i.e, if height > width, then image will be rescaled to
            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
            In torchscript mode size as single int is not supported, use a tuple or
            list of length 1: ``[size, ]``.
        interpolation (int, optional): Desired interpolation enum defined by `filters`_.
            Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR``
@@ -988,3 +988,63 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool
    img[..., i:i + h, j:j + w] = v
    return img
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor:
    """Performs Gaussian blurring on the img by given kernel.
    The image can be a PIL Image or a Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    Args:
        img (PIL Image or Tensor): Image to be blurred
        kernel_size (sequence of ints or int): Gaussian kernel size. Can be a sequence of integers
            like ``(kx, ky)`` or a single integer for square kernels.
            In torchscript mode kernel_size as single int is not supported, use a tuple or
            list of length 1: ``[ksize, ]``.
        sigma (sequence of floats or float, optional): Gaussian kernel standard deviation. Can be a
            sequence of floats like ``(sigma_x, sigma_y)`` or a single float to define the
            same sigma in both X/Y directions. If None, then it is computed using
            ``kernel_size`` as ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``.
            Default, None. In torchscript mode sigma as single float is
            not supported, use a tuple or list of length 1: ``[sigma, ]``.

    Returns:
        PIL Image or Tensor: Gaussian Blurred version of the image.
    """
    if not isinstance(kernel_size, (int, list, tuple)):
        raise TypeError('kernel_size should be int or a sequence of integers. Got {}'.format(type(kernel_size)))
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size, kernel_size]
    if len(kernel_size) != 2:
        raise ValueError('If kernel_size is a sequence its length should be 2. Got {}'.format(len(kernel_size)))
    for ksize in kernel_size:
        if ksize % 2 == 0 or ksize < 0:
            raise ValueError('kernel_size should have odd and positive integers. Got {}'.format(kernel_size))

    if sigma is None:
        sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]

    if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
        raise TypeError('sigma should be either float or sequence of floats. Got {}'.format(type(sigma)))
    if isinstance(sigma, (int, float)):
        sigma = [float(sigma), float(sigma)]
    if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
        sigma = [sigma[0], sigma[0]]
    if len(sigma) != 2:
        raise ValueError('If sigma is a sequence, its length should be 2. Got {}'.format(len(sigma)))
    for s in sigma:
        if s <= 0.:
            raise ValueError('sigma should have positive values. Got {}'.format(sigma))

    t_img = img
    if not isinstance(img, torch.Tensor):
        if not F_pil._is_pil_image(img):
            raise TypeError('img should be PIL Image or Tensor. Got {}'.format(type(img)))

        t_img = to_tensor(img)

    output = F_t.gaussian_blur(t_img, kernel_size, sigma)

    if not isinstance(img, torch.Tensor):
        output = to_pil_image(output)
    return output
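For orientation while reading the diff, a minimal usage sketch of the functional API defined above; it is illustrative rather than part of the commit, and the tensor shape and parameter values are arbitrary.

import torch
from torchvision.transforms import functional as F

img = torch.rand(3, 64, 64)   # any [..., H, W] tensor; a PIL Image also works in eager mode

out = F.gaussian_blur(img, kernel_size=[5, 5], sigma=[1.2, 1.2])
# When sigma is omitted it is derived from the kernel size; the docstring's
# 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8 equals the kernel_size * 0.15 + 0.35
# computed in the body.
out_default = F.gaussian_blur(img, kernel_size=[5, 5])
assert out.shape == img.shape and out_default.shape == img.shape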
@@ -3,7 +3,7 @@ from typing import Any, List, Sequence
import numpy as np
import torch
from PIL import Image, ImageOps, ImageEnhance, ImageFilter, __version__ as PILLOW_VERSION

try:
    import accimage
@@ -3,7 +3,7 @@ from typing import Optional, Dict, Tuple
import torch
from torch import Tensor
from torch.nn.functional import grid_sample, conv2d, interpolate, pad as torch_pad
from torch.jit.annotations import List, BroadcastingList2
@@ -746,7 +746,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
        need_cast = True
        img = img.to(torch.float32)

    img = torch_pad(img, p, mode=padding_mode, value=float(fill))

    if need_squeeze:
        img = img.squeeze(dim=0)
@@ -839,7 +839,7 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
    # Define align_corners to avoid warnings
    align_corners = False if mode in ["bilinear", "bicubic"] else None

    img = interpolate(img, size=[size_h, size_w], mode=mode, align_corners=align_corners)

    if need_squeeze:
        img = img.squeeze(dim=0)
@@ -879,24 +879,22 @@ def _assert_grid_transform_inputs(
        raise ValueError("Resampling mode '{}' is unsupported with Tensor input".format(resample))


def _cast_squeeze_in(img: Tensor, req_dtype: torch.dtype) -> Tuple[Tensor, bool, bool, torch.dtype]:
    need_squeeze = False
    # make image NCHW
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = False
    if out_dtype != req_dtype:
        need_cast = True
        img = img.to(req_dtype)
    return img, need_cast, need_squeeze, out_dtype


def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype):
    if need_squeeze:
        img = img.squeeze(dim=0)
@@ -907,6 +905,19 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor:

    return img
def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor:

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, grid.dtype)

    if img.shape[0] > 1:
        # Apply same grid to a batch of images
        grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])
    img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)

    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
    return img
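A small illustration, not from the commit, of the helper pair factored out above; it assumes the collapsed portion of _cast_squeeze_out casts the result back to out_dtype, as its signature implies.

# Round-trip sketch: _cast_squeeze_in promotes an image to a 4D batch in the dtype
# an op needs; _cast_squeeze_out undoes the bookkeeping afterwards.
img = torch.arange(3 * 4 * 4, dtype=torch.uint8).reshape(3, 4, 4)
x, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, torch.float32)
# x is now float32 with shape (1, 3, 4, 4); grid_sample or conv2d can run on it
y = _cast_squeeze_out(x, need_cast, need_squeeze, out_dtype)
assert y.shape == img.shape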
def _gen_affine_grid(
    theta: Tensor, w: int, h: int, ow: int, oh: int,
) -> Tensor:
@@ -1109,3 +1120,56 @@ def perspective(
    mode = _interpolation_modes[interpolation]

    return _apply_grid_transform(img, grid, mode)
def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
    ksize_half = (kernel_size - 1) * 0.5

    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    kernel1d = pdf / pdf.sum()

    return kernel1d


def _get_gaussian_kernel2d(
        kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
    kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
    kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
    kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
    return kernel2d
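A tiny sanity check, illustrative only: because each 1D kernel is normalized, their outer product sums to one, and the resulting 2D kernel has shape (ky, kx).

kernel1d_x = _get_gaussian_kernel1d(5, 1.0)
kernel1d_y = _get_gaussian_kernel1d(3, 1.0)
kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
assert kernel2d.shape == (3, 5)                  # rows follow y, columns follow x
assert abs(float(kernel2d.sum()) - 1.0) < 1e-6   # normalized separable Gaussian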
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
    """PRIVATE METHOD. Performs Gaussian blurring on the img by given kernel.

    .. warning::

        Module ``transforms.functional_tensor`` is private and should not be used in user application.
        Please, consider instead using methods from `transforms.functional` module.

    Args:
        img (Tensor): Image to be blurred
        kernel_size (sequence of int or int): Kernel size of the Gaussian kernel ``(kx, ky)``.
        sigma (sequence of float or float, optional): Standard deviation of the Gaussian kernel ``(sx, sy)``.

    Returns:
        Tensor: An image that is blurred using gaussian kernel of given parameters
    """
    if not (isinstance(img, torch.Tensor) or _is_tensor_a_torch_image(img)):
        raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))

    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, kernel.dtype)

    # padding = (left, right, top, bottom)
    padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
    img = torch_pad(img, padding, mode="reflect")
    img = conv2d(img, kernel, groups=img.shape[-3])

    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
    return img
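A standalone sketch of the depthwise-convolution idea used above, not part of the commit: the same 2D kernel is expanded to one filter per channel and applied with groups equal to the channel count, while reflect padding keeps the spatial size. The box kernel here is a stand-in for the Gaussian kernel.

import torch
from torch.nn.functional import conv2d, pad as torch_pad

img = torch.rand(1, 3, 32, 32)                     # NCHW batch
kernel = torch.ones(3, 3) / 9.0                    # stand-in for the Gaussian kernel
weight = kernel.expand(img.shape[1], 1, 3, 3)      # (C, 1, kh, kw)

padded = torch_pad(img, [1, 1, 1, 1], mode="reflect")
out = conv2d(padded, weight, groups=img.shape[1])  # depthwise: one kernel per channel
assert out.shape == img.shape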
@@ -3,7 +3,7 @@ import numbers
import random
import warnings
from collections.abc import Sequence
from typing import Tuple, List, Optional

import torch
from PIL import Image
@@ -20,7 +20,7 @@ __all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImag
           "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
           "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
           "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
           "RandomPerspective", "RandomErasing", "GaussianBlur"]
@@ -1494,6 +1494,73 @@ class RandomErasing(torch.nn.Module):
        return img
class GaussianBlur(torch.nn.Module):
    """Blurs image with randomly chosen Gaussian blur.
    The image can be a PIL Image or a Tensor, in which case it is expected
    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading
    dimensions

    Args:
        kernel_size (int or sequence): Size of the Gaussian kernel.
        sigma (float or tuple of float (min, max)): Standard deviation to be used for
            creating kernel to perform blurring. If float, sigma is fixed. If it is tuple
            of float (min, max), sigma is chosen uniformly at random to lie in the
            given range.

    Returns:
        PIL Image or Tensor: Gaussian blurred version of the input image.
    """

    def __init__(self, kernel_size, sigma=(0.1, 2.0)):
        super().__init__()
        self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
        for ks in self.kernel_size:
            if ks <= 0 or ks % 2 == 0:
                raise ValueError("Kernel size value should be an odd and positive number.")

        if isinstance(sigma, numbers.Number):
            if sigma <= 0:
                raise ValueError("If sigma is a single number, it must be positive.")
            sigma = (sigma, sigma)
        elif isinstance(sigma, Sequence) and len(sigma) == 2:
            if not 0. < sigma[0] <= sigma[1]:
                raise ValueError("sigma values should be positive and of the form (min, max).")
        else:
            raise ValueError("sigma should be a single number or a list/tuple with length 2.")

        self.sigma = sigma

    @staticmethod
    def get_params(sigma_min: float, sigma_max: float) -> float:
        """Choose sigma for ``gaussian_blur`` for random gaussian blurring.

        Args:
            sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel.
            sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel.

        Returns:
            float: Standard deviation to be passed to calculate kernel for gaussian blurring.
        """
        return torch.empty(1).uniform_(sigma_min, sigma_max).item()

    def forward(self, img: Tensor) -> Tensor:
        """
        Args:
            img (PIL Image or Tensor): image of size (C, H, W) to be blurred.

        Returns:
            PIL Image or Tensor: Gaussian blurred image
        """
        sigma = self.get_params(self.sigma[0], self.sigma[1])
        return F.gaussian_blur(img, self.kernel_size, [sigma, sigma])

    def __repr__(self):
        s = '(kernel_size={}, '.format(self.kernel_size)
        s += 'sigma={})'.format(self.sigma)
        return self.__class__.__name__ + s
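A brief usage sketch of the transform class added above, illustrative only and not part of the commit; parameter values are arbitrary.

import torch
from torchvision import transforms

blur = transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))
img = torch.rand(3, 224, 224)   # a PIL Image would work as well
out = blur(img)                 # sigma is re-sampled uniformly on every call

# The transform is a torch.nn.Module, so it composes and, per the tests above, scripts:
pipeline = transforms.Compose([transforms.GaussianBlur(3, sigma=1.0)])
scripted = torch.jit.script(transforms.GaussianBlur(3, sigma=1.0))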
def _setup_size(size, error_msg):
    if isinstance(size, numbers.Number):
        return int(size), int(size)