Unverified Commit 1921613a authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

[prototype] Gaussian Blur clean up (#6888)

* Refactor gaussian_blur

* Add conditional reshape

* Further refactoring

* Remove unused import.
parent c4c0ef98
...@@ -5,7 +5,6 @@ import PIL.Image ...@@ -5,7 +5,6 @@ import PIL.Image
import torch import torch
from torch.nn.functional import conv2d, pad as torch_pad from torch.nn.functional import conv2d, pad as torch_pad
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.transforms import functional_tensor as _FT
from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.transforms.functional import pil_to_tensor, to_pil_image
...@@ -68,9 +67,9 @@ def normalize( ...@@ -68,9 +67,9 @@ def normalize(
def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor: def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
lim = (kernel_size - 1) / (2 * math.sqrt(2) * sigma) lim = (kernel_size - 1) / (2.0 * math.sqrt(2.0) * sigma)
x = torch.linspace(-lim, lim, steps=kernel_size, dtype=dtype, device=device) x = torch.linspace(-lim, lim, steps=kernel_size, dtype=dtype, device=device)
kernel1d = torch.softmax(-x.pow_(2), dim=0) kernel1d = torch.softmax(x.pow_(2).neg_(), dim=0)
return kernel1d return kernel1d
...@@ -89,7 +88,7 @@ def gaussian_blur_image_tensor( ...@@ -89,7 +88,7 @@ def gaussian_blur_image_tensor(
# TODO: consider deprecating integers from sigma on the future # TODO: consider deprecating integers from sigma on the future
if isinstance(kernel_size, int): if isinstance(kernel_size, int):
kernel_size = [kernel_size, kernel_size] kernel_size = [kernel_size, kernel_size]
if len(kernel_size) != 2: elif len(kernel_size) != 2:
raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}") raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
for ksize in kernel_size: for ksize in kernel_size:
if ksize % 2 == 0 or ksize < 0: if ksize % 2 == 0 or ksize < 0:
...@@ -97,15 +96,19 @@ def gaussian_blur_image_tensor( ...@@ -97,15 +96,19 @@ def gaussian_blur_image_tensor(
if sigma is None: if sigma is None:
sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size] sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]
else:
if sigma is not None and not isinstance(sigma, (int, float, list, tuple)): if isinstance(sigma, (list, tuple)):
length = len(sigma)
if length == 1:
s = float(sigma[0])
sigma = [s, s]
elif length != 2:
raise ValueError(f"If sigma is a sequence, its length should be 2. Got {length}")
elif isinstance(sigma, (int, float)):
s = float(sigma)
sigma = [s, s]
else:
raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}") raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
if isinstance(sigma, (int, float)):
sigma = [float(sigma), float(sigma)]
if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
sigma = [sigma[0], sigma[0]]
if len(sigma) != 2:
raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}")
for s in sigma: for s in sigma:
if s <= 0.0: if s <= 0.0:
raise ValueError(f"sigma should have positive values. Got {sigma}") raise ValueError(f"sigma should have positive values. Got {sigma}")
...@@ -113,30 +116,33 @@ def gaussian_blur_image_tensor( ...@@ -113,30 +116,33 @@ def gaussian_blur_image_tensor(
if image.numel() == 0: if image.numel() == 0:
return image return image
dtype = image.dtype
shape = image.shape shape = image.shape
ndim = image.ndim
if image.ndim > 4: if ndim == 3:
image = image.unsqueeze(dim=0)
elif ndim > 4:
image = image.reshape((-1,) + shape[-3:]) image = image.reshape((-1,) + shape[-3:])
needs_unsquash = True
else:
needs_unsquash = False
dtype = image.dtype if torch.is_floating_point(image) else torch.float32 fp = torch.is_floating_point(image)
kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=image.device) kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype if fp else torch.float32, device=image.device)
kernel = kernel.expand(image.shape[-3], 1, kernel.shape[0], kernel.shape[1]) kernel = kernel.expand(shape[-3], 1, kernel.shape[0], kernel.shape[1])
image, need_cast, need_squeeze, out_dtype = _FT._cast_squeeze_in(image, [kernel.dtype]) output = image if fp else image.to(dtype=torch.float32)
# padding = (left, right, top, bottom) # padding = (left, right, top, bottom)
padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2] padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
output = torch_pad(image, padding, mode="reflect") output = torch_pad(output, padding, mode="reflect")
output = conv2d(output, kernel, groups=output.shape[-3]) output = conv2d(output, kernel, groups=shape[-3])
output = _FT._cast_squeeze_out(output, need_cast, need_squeeze, out_dtype) if ndim == 3:
output = output.squeeze(dim=0)
if needs_unsquash: elif ndim > 4:
output = output.reshape(shape) output = output.reshape(shape)
if not fp:
output = output.round_().to(dtype=dtype)
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment