"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "716aa416053fbfaf468db357e2215429c098d975"
Unverified commit 10d47a66, authored by Vasilis Vryniotis and committed by GitHub
Browse files

[prototype] Speed up `adjust_contrast_image_tensor` (#6933)

* Avoid double casting on adjust_contrast

* Handle properly ints.
parent f32600b6
...@@ -79,9 +79,14 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> ...@@ -79,9 +79,14 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
c = image.shape[-3] c = image.shape[-3]
if c not in [1, 3]: if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}") raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
dtype = image.dtype if torch.is_floating_point(image) else torch.float32 fp = image.is_floating_point()
grayscale_image = _rgb_to_gray(image) if c == 3 else image if c == 3:
mean = torch.mean(grayscale_image.to(dtype), dim=(-3, -2, -1), keepdim=True) grayscale_image = _rgb_to_gray(image, cast=False)
if not fp:
grayscale_image = grayscale_image.floor_()
else:
grayscale_image = image if fp else image.to(torch.float32)
mean = torch.mean(grayscale_image, dim=(-3, -2, -1), keepdim=True)
return _blend(image, mean, contrast_factor) return _blend(image, mean, contrast_factor)
......
...@@ -213,10 +213,12 @@ def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor: ...@@ -213,10 +213,12 @@ def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
return grayscale.repeat(repeats) return grayscale.repeat(repeats)
def _rgb_to_gray(image: torch.Tensor) -> torch.Tensor: def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor:
r, g, b = image.unbind(dim=-3) r, g, b = image.unbind(dim=-3)
l_img = (0.2989 * r).add_(g, alpha=0.587).add_(b, alpha=0.114) l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
l_img = l_img.to(image.dtype).unsqueeze(dim=-3) if cast:
l_img = l_img.to(image.dtype)
l_img = l_img.unsqueeze(dim=-3)
return l_img return l_img
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment