Unverified Commit 7a7ab7e7 authored by Vasilis Vryniotis, committed by GitHub

[prototype] Speed up `adjust_sharpness_image_tensor` (#6930)

* Speed up `adjust_sharpness_image_tensor`

* Add a comment
parent bf58902b
 import torch
+from torch.nn.functional import conv2d
 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

@@ -111,6 +112,8 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
     if image.numel() == 0 or height <= 2 or width <= 2:
         return image
 
+    bound = _FT._max_value(image.dtype)
+    fp = image.is_floating_point()
     shape = image.shape
 
     if image.ndim > 4:
@@ -119,7 +122,30 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
     else:
         needs_unsquash = False
 
-    output = _blend(image, _FT._blurred_degenerate_image(image), sharpness_factor)
+    # The following is a normalized 3x3 kernel with 1s on the edges and a 5 in the middle.
+    kernel_dtype = image.dtype if fp else torch.float32
+    a, b = 1.0 / 13.0, 5.0 / 13.0
+    kernel = torch.tensor([[a, a, a], [a, b, a], [a, a, a]], dtype=kernel_dtype, device=image.device)
+    kernel = kernel.expand(num_channels, 1, 3, 3)
+
+    # We copy and cast at the same time to avoid modifying the original data.
+    output = image.to(dtype=kernel_dtype, copy=True)
+    blurred_degenerate = conv2d(output, kernel, groups=num_channels)
+    if not fp:
+        # It is better to round before the cast.
+        blurred_degenerate = blurred_degenerate.round_()
+
+    # Create a view on the interior of output that points at the same data. We do this to avoid indexing twice.
+    view = output[..., 1:-1, 1:-1]
+
+    # We speed up blending by minimizing flops and doing it in-place. The two blend options are mathematically
+    # equivalent: x + (1-r)*(y-x) = x + (1-r)*y - (1-r)*x = x*r + y*(1-r)
+    view.add_(blurred_degenerate.sub_(view), alpha=(1.0 - sharpness_factor))
+
+    # The actual data of output have been modified by the above. We only need to clamp and cast now.
+    output = output.clamp_(0, bound)
+    if not fp:
+        output = output.to(image.dtype)
 
     if needs_unsquash:
         output = output.reshape(shape)
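For readers without the surrounding context, here is a minimal standalone sketch of what the new grouped convolution computes; the input tensor and its shape are made up for illustration. It also shows why the blend writes into `output[..., 1:-1, 1:-1]`: a 3x3 convolution without padding trims one pixel from each border.

```python
import torch
from torch.nn.functional import conv2d

# Hypothetical uint8 image batch, shape (N, C, H, W).
img = torch.randint(0, 256, (1, 3, 32, 32), dtype=torch.uint8)
num_channels = img.shape[-3]

# Same normalized smoothing kernel as in the diff: 1s on the edges,
# 5 in the middle, all entries summing to 1.
a, b = 1.0 / 13.0, 5.0 / 13.0
kernel = torch.tensor([[a, a, a], [a, b, a], [a, a, a]])
kernel = kernel.expand(num_channels, 1, 3, 3)

# groups=num_channels blurs each channel independently with the same 3x3 kernel.
# With no padding, the output loses a 1-pixel border on every side.
blurred = conv2d(img.float(), kernel, groups=num_channels)
print(blurred.shape)  # torch.Size([1, 3, 30, 30]) -- matches img[..., 1:-1, 1:-1]
```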
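A quick numerical check of the blend identity quoted in the diff's comment: the in-place `add_`/`sub_` formulation equals the straightforward `x*r + y*(1-r)`. The tensor names `x`, `y`, `r` and the shapes are illustrative only, not from the diff.

```python
import torch

torch.manual_seed(0)
x = torch.rand(3, 30, 30)  # stands in for the interior view of the original image
y = torch.rand(3, 30, 30)  # stands in for the blurred degenerate
r = 1.7                    # a sharpness_factor > 1 sharpens, < 1 blurs

# Straightforward blend: x * r + y * (1 - r)
expected = x * r + y * (1 - r)

# In-place variant from the diff: x + (1 - r) * (y - x),
# which needs fewer temporaries and fewer passes over the data.
out = x.clone()
out.add_(y.clone().sub_(out), alpha=(1.0 - r))

torch.testing.assert_close(out, expected)  # equal up to floating-point rounding
```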