"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5934873b8f4c1dda00a6271bc40fd2a45a1a918e"
Unverified commit decb1919, authored by vfdev, committed by GitHub

[proto] Small optimization for gaussian_blur functional op (#6762)

* Use softmax in _get_gaussian_kernel1d

* Revert "Use softmax in _get_gaussian_kernel1d"

This reverts commit eb8fba36302d2da9e06e6f40afaaf901b276a771.

* Code update

* Relaxed tolerance in consistency tests for GaussianBlur and ElasticTransform

* Code review updates

* Update test_prototype_transforms_consistency.py
parent 149edda4
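
The "Use softmax in _get_gaussian_kernel1d" bullet refers to the trick visible in the second diff below: sampling x_i = (i - (kernel_size - 1) / 2) / (sqrt(2) * sigma) and taking softmax(-x**2) produces exactly the normalized Gaussian exp(-(i - c)**2 / (2 * sigma**2)) / sum(...), so the explicit exp-plus-division of the older kernel helper collapses into a single, numerically stable op. A minimal sketch of that equivalence; _get_gaussian_kernel1d_reference is an illustrative reimplementation in the style of the previous exp-based formulation, not the exact library code:

import math

import torch


def _get_gaussian_kernel1d_softmax(kernel_size: int, sigma: float) -> torch.Tensor:
    # New formulation from this commit: with x scaled by 1 / (sqrt(2) * sigma),
    # softmax(-x**2) equals exp(-(i - c)**2 / (2 * sigma**2)) / sum(...).
    lim = (kernel_size - 1) / (2 * math.sqrt(2) * sigma)
    x = torch.linspace(-lim, lim, steps=kernel_size)
    return torch.softmax(-x.pow_(2), dim=0)


def _get_gaussian_kernel1d_reference(kernel_size: int, sigma: float) -> torch.Tensor:
    # Illustrative exp-based formulation (assumption: mirrors the older helper's math).
    x = torch.linspace(-(kernel_size - 1) / 2, (kernel_size - 1) / 2, steps=kernel_size)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    return pdf / pdf.sum()


# The two kernels agree to floating-point precision.
torch.testing.assert_close(
    _get_gaussian_kernel1d_softmax(5, 1.2),
    _get_gaussian_kernel1d_reference(5, 1.2),
)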
@@ -308,22 +308,28 @@ CONSISTENCY_CONFIGS = [
             ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.7, hue=0.3),
         ],
     ),
-    ConsistencyConfig(
-        prototype_transforms.ElasticTransform,
-        legacy_transforms.ElasticTransform,
-        [
-            ArgsKwargs(),
-            ArgsKwargs(alpha=20.0),
-            ArgsKwargs(alpha=(15.3, 27.2)),
-            ArgsKwargs(sigma=3.0),
-            ArgsKwargs(sigma=(2.5, 3.9)),
-            ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.NEAREST),
-            ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.BICUBIC),
-            ArgsKwargs(fill=1),
-        ],
-        # ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
-        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(163, 163), (72, 333), (313, 95)]),
-    ),
+    *[
+        ConsistencyConfig(
+            prototype_transforms.ElasticTransform,
+            legacy_transforms.ElasticTransform,
+            [
+                ArgsKwargs(),
+                ArgsKwargs(alpha=20.0),
+                ArgsKwargs(alpha=(15.3, 27.2)),
+                ArgsKwargs(sigma=3.0),
+                ArgsKwargs(sigma=(2.5, 3.9)),
+                ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.NEAREST),
+                ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.BICUBIC),
+                ArgsKwargs(fill=1),
+            ],
+            # ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
+            make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(163, 163), (72, 333), (313, 95)], dtypes=[dt]),
+            # We updated gaussian blur kernel generation with a faster and numerically more stable version
+            # This brings float32 accumulation visible in elastic transform -> we need to relax consistency tolerance
+            closeness_kwargs=ckw,
+        )
+        for dt, ckw in [(torch.uint8, {"rtol": 1e-1, "atol": 1}), (torch.float32, {"rtol": 1e-2, "atol": 1e-3})]
+    ],
     ConsistencyConfig(
         prototype_transforms.GaussianBlur,
         legacy_transforms.GaussianBlur,
@@ -333,6 +339,7 @@ CONSISTENCY_CONFIGS = [
             ArgsKwargs(kernel_size=3, sigma=0.7),
             ArgsKwargs(kernel_size=5, sigma=(0.3, 1.4)),
         ],
+        closeness_kwargs={"rtol": 1e-5, "atol": 1e-5},
     ),
     ConsistencyConfig(
         prototype_transforms.RandomAffine,
@@ -506,7 +513,6 @@ def check_call_consistency(
         image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"
         image_tensor = torch.Tensor(image)
         try:
             torch.manual_seed(0)
             output_legacy_tensor = legacy_transform(image_tensor)
...
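
Note the structural change in the first hunk above: the single ElasticTransform entry becomes a starred list comprehension, so CONSISTENCY_CONFIGS gets one ElasticTransform config per (dtype, tolerance) pair, with uint8 allowed to differ by one intensity level and float32 held to a tighter bound. A minimal sketch of that unpacking pattern, where Config is a hypothetical stand-in for ConsistencyConfig:

import torch

Config = dict  # hypothetical stand-in for ConsistencyConfig

CONFIGS = [
    Config(name="GaussianBlur", closeness={"rtol": 1e-5, "atol": 1e-5}),
    # A starred comprehension expands in place, one entry per (dtype, tolerance) pair.
    *[
        Config(name="ElasticTransform", dtype=dt, closeness=ckw)
        for dt, ckw in [
            (torch.uint8, {"rtol": 1e-1, "atol": 1}),
            (torch.float32, {"rtol": 1e-2, "atol": 1e-3}),
        ]
    ],
]
print(len(CONFIGS))  # 3: one GaussianBlur entry plus one ElasticTransform entry per dtype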
+import math
 from typing import List, Optional, Union
 import PIL.Image
 import torch
+from torch.nn.functional import conv2d, pad as torch_pad
 from torchvision.prototype import features
 from torchvision.transforms import functional_tensor as _FT
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
@@ -32,6 +34,22 @@ def normalize(
     return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)


+def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
+    lim = (kernel_size - 1) / (2 * math.sqrt(2) * sigma)
+    x = torch.linspace(-lim, lim, steps=kernel_size)
+    kernel1d = torch.softmax(-x.pow_(2), dim=0)
+    return kernel1d
+
+
+def _get_gaussian_kernel2d(
+    kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
+) -> torch.Tensor:
+    kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
+    kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
+    kernel2d = kernel1d_y.unsqueeze(-1) * kernel1d_x
+    return kernel2d
+
+
 def gaussian_blur_image_tensor(
     image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> torch.Tensor:
@@ -70,7 +88,18 @@ def gaussian_blur_image_tensor(
     else:
         needs_unsquash = False

-    output = _FT.gaussian_blur(image, kernel_size, sigma)
+    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
+    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=image.device)
+    kernel = kernel.expand(image.shape[-3], 1, kernel.shape[0], kernel.shape[1])
+
+    image, need_cast, need_squeeze, out_dtype = _FT._cast_squeeze_in(image, [kernel.dtype])
+
+    # padding = (left, right, top, bottom)
+    padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
+    output = torch_pad(image, padding, mode="reflect")
+    output = conv2d(output, kernel, groups=output.shape[-3])
+
+    output = _FT._cast_squeeze_out(output, need_cast, need_squeeze, out_dtype)

     if needs_unsquash:
         output = output.reshape(shape)
...
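
Putting the pieces of the second diff together: the new gaussian_blur_image_tensor path builds a separable 2D kernel as the outer product of two softmax-based 1D kernels, expands it to one kernel per channel, reflect-pads the image, and runs a grouped (depthwise) conv2d. A self-contained sketch of that flow, assuming a float NCHW input; the _cast_squeeze_in/_cast_squeeze_out dtype and shape handling from the real function is omitted, and gaussian_blur_sketch is an illustrative name, not the library API:

import math

import torch
from torch.nn.functional import conv2d, pad as torch_pad


def gaussian_blur_sketch(image: torch.Tensor, kernel_size, sigma):
    # image: float tensor of shape (N, C, H, W); kernel_size = [kx, ky], sigma = [sx, sy]
    def kernel1d(ks, s):
        lim = (ks - 1) / (2 * math.sqrt(2) * s)
        x = torch.linspace(-lim, lim, steps=ks)
        return torch.softmax(-x.pow_(2), dim=0)

    kx = kernel1d(kernel_size[0], sigma[0]).to(image.device, dtype=image.dtype)
    ky = kernel1d(kernel_size[1], sigma[1]).to(image.device, dtype=image.dtype)
    kernel2d = ky.unsqueeze(-1) * kx  # outer product -> shape (ky, kx)
    kernel = kernel2d.expand(image.shape[-3], 1, *kernel2d.shape)  # one kernel per channel

    # padding = (left, right, top, bottom), mirroring the diff above
    padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
    output = torch_pad(image, padding, mode="reflect")
    return conv2d(output, kernel, groups=output.shape[-3])


img = torch.rand(1, 3, 32, 32)
blurred = gaussian_blur_sketch(img, kernel_size=[5, 5], sigma=[1.5, 1.5])
print(blurred.shape)  # torch.Size([1, 3, 32, 32]); spatial size is preserved by the reflect padding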