Unverified Commit edb3a806 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

[prototype] Minor improvements on functional (#6832)

* Minor improvements on functional.

* Restore `_split_alpha`.

* Revert "Restore `_split_alpha`."

This reverts commit 2286120be6d4af2a3c9b52b605d87611ec70fe06.
parent b45969a7
......@@ -188,7 +188,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
h, s, v = img.unbind(dim=-3)
h6 = h * 6
i = torch.floor(h6)
f = (h6) - i
f = h6 - i
i = i.to(dtype=torch.int32)
p = (v * (1.0 - s)).clamp_(0.0, 1.0)
......@@ -210,9 +210,6 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
if not (isinstance(image, torch.Tensor)):
raise TypeError("Input img should be Tensor image")
c = get_num_channels_image_tensor(image)
if c not in [1, 3]:
......@@ -258,9 +255,6 @@ def adjust_hue(inpt: features.InputTypeJIT, hue_factor: float) -> features.Input
def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
if not (isinstance(image, torch.Tensor)):
raise TypeError("Input img should be Tensor image")
if gamma < 0:
raise ValueError("Gamma should be a non-negative real number")
......@@ -337,10 +331,6 @@ def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTyp
def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
if not (isinstance(image, torch.Tensor)):
raise TypeError("Input img should be Tensor image")
c = get_num_channels_image_tensor(image)
if c not in [1, 3]:
......
......@@ -183,12 +183,8 @@ def clamp_bounding_box(
return convert_format_bounding_box(xyxy_boxes, BoundingBoxFormat.XYXY, format)
def _split_alpha(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return image[..., :-1, :, :], image[..., -1:, :, :]
def _strip_alpha(image: torch.Tensor) -> torch.Tensor:
image, alpha = _split_alpha(image)
image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
if not torch.all(alpha == _FT._max_value(alpha.dtype)):
raise RuntimeError(
"Stripping the alpha channel if it contains values other than the max value is not supported."
......@@ -237,7 +233,7 @@ def convert_color_space_image_tensor(
elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB:
return _gray_to_rgb(_strip_alpha(image))
elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB_ALPHA:
image, alpha = _split_alpha(image)
image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
return _add_alpha(_gray_to_rgb(image), alpha)
elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY:
return _rgb_to_gray(image)
......@@ -248,7 +244,7 @@ def convert_color_space_image_tensor(
elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY:
return _rgb_to_gray(_strip_alpha(image))
elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY_ALPHA:
image, alpha = _split_alpha(image)
image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
return _add_alpha(_rgb_to_gray(image), alpha)
elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.RGB:
return _strip_alpha(image)
......
......@@ -67,9 +67,9 @@ def normalize(
return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
lim = (kernel_size - 1) / (2 * math.sqrt(2) * sigma)
x = torch.linspace(-lim, lim, steps=kernel_size)
x = torch.linspace(-lim, lim, steps=kernel_size, dtype=dtype, device=device)
kernel1d = torch.softmax(-x.pow_(2), dim=0)
return kernel1d
......@@ -77,8 +77,8 @@ def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
def _get_gaussian_kernel2d(
kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0], dtype, device)
kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1], dtype, device)
kernel2d = kernel1d_y.unsqueeze(-1) * kernel1d_x
return kernel2d
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment