"docs/vscode:/vscode.git/clone" did not exist on "ce363d0e06500f1b449b47cb80b648913bca8716"
Unverified Commit 10d47a66 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

[prototype] Speed up `adjust_contrast_image_tensor` (#6933)

* Avoid double casting on adjust_contrast

* Handle properly ints.
parent f32600b6
......@@ -79,9 +79,14 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
c = image.shape[-3]
if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
grayscale_image = _rgb_to_gray(image) if c == 3 else image
mean = torch.mean(grayscale_image.to(dtype), dim=(-3, -2, -1), keepdim=True)
fp = image.is_floating_point()
if c == 3:
grayscale_image = _rgb_to_gray(image, cast=False)
if not fp:
grayscale_image = grayscale_image.floor_()
else:
grayscale_image = image if fp else image.to(torch.float32)
mean = torch.mean(grayscale_image, dim=(-3, -2, -1), keepdim=True)
return _blend(image, mean, contrast_factor)
......
......@@ -213,10 +213,12 @@ def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
return grayscale.repeat(repeats)
def _rgb_to_gray(image: torch.Tensor) -> torch.Tensor:
def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor:
r, g, b = image.unbind(dim=-3)
l_img = (0.2989 * r).add_(g, alpha=0.587).add_(b, alpha=0.114)
l_img = l_img.to(image.dtype).unsqueeze(dim=-3)
l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
if cast:
l_img = l_img.to(image.dtype)
l_img = l_img.unsqueeze(dim=-3)
return l_img
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment