"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "6db8a3771c29d070ef165cca0d7e8dbda3fc341e"
Unverified commit 9559188c authored by Philip Meier, committed by GitHub
Browse files

allow grayscale tensor images in `rgb_to_grayscale` (#6474)

* add deprecated color conversion functionals

* allow grayscale tensor inputs in rgb_to_grayscale

* add cloning to tensor no-op

* improve todo comment

* [skip ci]

* use legacy kernels

use legacy kernels

* fix import

* remove duplicate code

* remove duplicate check
parent b4b246a5
...@@ -2287,5 +2287,17 @@ def test_elastic_transformation(): ...@@ -2287,5 +2287,17 @@ def test_elastic_transformation():
t.__repr__() t.__repr__()
def test_random_grayscale_with_grayscale_input():
    """RandomGrayscale(p=1.0) must be a no-op for an already-grayscale image."""
    grayscale_transform = transforms.RandomGrayscale(p=1.0)

    # Single-channel uint8 tensor image: the transform should return it unchanged.
    tensor_input = torch.randint(0, 256, (1, 16, 16), dtype=torch.uint8)
    torch.testing.assert_close(grayscale_transform(tensor_input), tensor_input)

    # Same check through the PIL code path, round-tripping back to a tensor.
    pil_input = F.to_pil_image(tensor_input)
    pil_output = grayscale_transform(pil_input)
    torch.testing.assert_close(F.pil_to_tensor(pil_output), tensor_input)
# Allow running this test module directly instead of via the pytest CLI.
if __name__ == "__main__":
    pytest.main([__file__])
...@@ -65,5 +65,8 @@ def has_all(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) - ...@@ -65,5 +65,8 @@ def has_all(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -
return True return True
# TODO: Given that this is not related to pytree / the Transform object, we should probably move it to somewhere else.
# One possibility is `functional._utils` so both the functionals and the transforms have proper access to it. We could
# also move it to `features` since it literally checks for the _Feature type.
def is_simple_tensor(inpt: Any) -> bool:
    """Return True if *inpt* is a plain ``torch.Tensor`` (not a prototype ``_Feature``)."""
    if not isinstance(inpt, torch.Tensor):
        return False
    return not isinstance(inpt, features._Feature)
...@@ -113,3 +113,5 @@ from ._type_conversion import ( ...@@ -113,3 +113,5 @@ from ._type_conversion import (
to_image_pil, to_image_pil,
to_image_tensor, to_image_tensor,
) )
from ._deprecated import rgb_to_grayscale, to_grayscale # usort: skip
import warnings
from typing import Any
import PIL.Image
from torchvision.prototype import features
from torchvision.transforms import functional as _F
from .._utils import is_simple_tensor
def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
    """Deprecated grayscale conversion for PIL images.

    Emits a deprecation warning pointing callers at the ``convert_color_space``
    replacement, then delegates to the legacy kernel
    ``torchvision.transforms.functional.to_grayscale``.

    Args:
        inpt: PIL image to convert.
        num_output_channels: 1 for a single gray channel, 3 to replicate the
            gray channel into an RGB-shaped image.

    Returns:
        The converted PIL image.
    """
    call = ", num_output_channels=3" if num_output_channels == 3 else ""
    replacement = "convert_color_space(..., color_space=features.ColorSpace.GRAY)"
    if num_output_channels == 3:
        # Replicating gray to 3 channels maps to a second conversion back to RGB.
        replacement = f"convert_color_space({replacement}, color_space=features.ColorSpace.RGB)"
    warnings.warn(
        # Fix: original message read "is deprecated in will be removed" (missing "and").
        f"The function `to_grayscale(...{call})` is deprecated and will be removed in a future release. "
        f"Instead, please use `{replacement}`.",
    )

    return _F.to_grayscale(inpt, num_output_channels=num_output_channels)
def rgb_to_grayscale(inpt: Any, num_output_channels: int = 1) -> Any:
    """Deprecated grayscale conversion for tensor or PIL inputs.

    Emits a deprecation warning whose suggested ``convert_color_space``
    replacement is tailored to the input, then delegates to the legacy kernel
    ``torchvision.transforms.functional.rgb_to_grayscale``.

    Args:
        inpt: Image to convert (plain tensor, ``features.Image``, or PIL image).
        num_output_channels: 1 for a single gray channel, 3 to replicate the
            gray channel into an RGB-shaped image.

    Returns:
        The converted image, same kind as the input.
    """
    # Only plain tensors need an explicit `old_color_space` hint in the
    # suggested replacement; feature/PIL inputs carry their own color space.
    old_color_space = features.Image.guess_color_space(inpt) if is_simple_tensor(inpt) else None

    call = ", num_output_channels=3" if num_output_channels == 3 else ""
    replacement = (
        f"convert_color_space(..., color_space=features.ColorSpace.GRAY"
        f"{f', old_color_space=features.ColorSpace.{old_color_space}' if old_color_space is not None else ''})"
    )
    if num_output_channels == 3:
        # After the first conversion the intermediate is GRAY, so the outer
        # call's old_color_space is always GRAY (when one is needed at all).
        replacement = (
            f"convert_color_space({replacement}, color_space=features.ColorSpace.RGB"
            f"{f', old_color_space=features.ColorSpace.GRAY' if old_color_space is not None else ''})"
        )
    warnings.warn(
        # Fix: original message read "is deprecated in will be removed" (missing "and").
        f"The function `rgb_to_grayscale(...{call})` is deprecated and will be removed in a future release. "
        f"Instead, please use `{replacement}`.",
    )

    return _F.rgb_to_grayscale(inpt, num_output_channels=num_output_channels)
...@@ -145,16 +145,19 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: ...@@ -145,16 +145,19 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
if img.ndim < 3: if img.ndim < 3:
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}") raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
_assert_channels(img, [3]) _assert_channels(img, [1, 3])
if num_output_channels not in (1, 3): if num_output_channels not in (1, 3):
raise ValueError("num_output_channels should be either 1 or 3") raise ValueError("num_output_channels should be either 1 or 3")
r, g, b = img.unbind(dim=-3) if img.shape[-3] == 3:
# This implementation closely follows the TF one: r, g, b = img.unbind(dim=-3)
# https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138 # This implementation closely follows the TF one:
l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype) # https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
l_img = l_img.unsqueeze(dim=-3) l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype)
l_img = l_img.unsqueeze(dim=-3)
else:
l_img = img.clone()
if num_output_channels == 3: if num_output_channels == 3:
return l_img.expand(img.shape) return l_img.expand(img.shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment