"src/array/vscode:/vscode.git/clone" did not exist on "c59000ac3ad12ad5e1f769267742d10174de4921"
Unverified Commit 2bababf2 authored by ahmadsharif1's avatar ahmadsharif1 Committed by GitHub
Browse files

Add a GrayscaleToRgb transform that can expand channels to 3 (#8247)

parent fa82fd3b
...@@ -347,6 +347,7 @@ Color ...@@ -347,6 +347,7 @@ Color
v2.RandomChannelPermutation v2.RandomChannelPermutation
v2.RandomPhotometricDistort v2.RandomPhotometricDistort
v2.Grayscale v2.Grayscale
v2.RGB
v2.RandomGrayscale v2.RandomGrayscale
v2.GaussianBlur v2.GaussianBlur
v2.RandomInvert v2.RandomInvert
...@@ -364,6 +365,7 @@ Functionals ...@@ -364,6 +365,7 @@ Functionals
v2.functional.permute_channels v2.functional.permute_channels
v2.functional.rgb_to_grayscale v2.functional.rgb_to_grayscale
v2.functional.grayscale_to_rgb
v2.functional.to_grayscale v2.functional.to_grayscale
v2.functional.gaussian_blur v2.functional.gaussian_blur
v2.functional.invert v2.functional.invert
......
...@@ -5005,6 +5005,54 @@ class TestRgbToGrayscale: ...@@ -5005,6 +5005,54 @@ class TestRgbToGrayscale:
assert_equal(actual, expected, rtol=0, atol=1) assert_equal(actual, expected, rtol=0, atol=1)
class TestGrayscaleToRgb:
    """Tests for ``F.grayscale_to_rgb``, its image kernel and the ``transforms.RGB`` transform."""

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, dtype, device):
        check_kernel(F.grayscale_to_rgb_image, make_image(dtype=dtype, device=device))

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
    def test_functional(self, make_input):
        check_functional(F.grayscale_to_rgb, make_input())

    # The kernels listed here must be the ones registered for `grayscale_to_rgb`.
    # Listing the `rgb_to_grayscale` kernels instead compares the functional against
    # unrelated kernels and never looks at the real ones.
    # NOTE(review): the PIL kernel is accessed through `F._color` because it does not
    # appear to be re-exported from the functional namespace -- confirm.
    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.grayscale_to_rgb_image, torch.Tensor),
            (F._color.grayscale_to_rgb_image_pil, PIL.Image.Image),
            (F.grayscale_to_rgb_image, tv_tensors.Image),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        check_functional_kernel_signature_match(F.grayscale_to_rgb, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
    def test_transform(self, make_input):
        check_transform(transforms.RGB(), make_input(color_space="GRAY"))

    @pytest.mark.parametrize("fn", [F.grayscale_to_rgb, transform_cls_to_functional(transforms.RGB)])
    def test_image_correctness(self, fn):
        image = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY")

        actual = fn(image)
        # The PIL round trip serves as the reference implementation.
        expected = F.to_image(F.grayscale_to_rgb(F.to_pil_image(image)))

        assert_equal(actual, expected, rtol=0, atol=1)

    def test_expanded_channels_are_not_views_into_the_same_underlying_tensor(self):
        image = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY")

        output_image = F.grayscale_to_rgb(image)
        assert_equal(output_image[0][0][0], output_image[1][0][0])

        # Writing to one channel must not leak into another one; that would happen
        # if the output channels shared the same memory.
        output_image[0][0][0] = output_image[0][0][0] + 1
        assert output_image[0][0][0] != output_image[1][0][0]

    def test_rgb_image_is_unchanged(self):
        image = make_image(dtype=torch.uint8, device="cpu", color_space="RGB")
        assert_equal(image.shape[-3], 3)
        assert_equal(F.grayscale_to_rgb(image), image)
class TestRandomZoomOut: class TestRandomZoomOut:
# Tests are light because this largely relies on the already tested `pad` kernels. # Tests are light because this largely relies on the already tested `pad` kernels.
......
...@@ -18,6 +18,7 @@ from ._color import ( ...@@ -18,6 +18,7 @@ from ._color import (
RandomPhotometricDistort, RandomPhotometricDistort,
RandomPosterize, RandomPosterize,
RandomSolarize, RandomSolarize,
RGB,
) )
from ._container import Compose, RandomApply, RandomChoice, RandomOrder from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from ._geometry import ( from ._geometry import (
......
...@@ -54,6 +54,20 @@ class RandomGrayscale(_RandomApplyTransform): ...@@ -54,6 +54,20 @@ class RandomGrayscale(_RandomApplyTransform):
return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"]) return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"])
class RGB(Transform):
    """Convert images or videos to RGB (if they are not already RGB).

    If the input is a :class:`torch.Tensor`, it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    """

    # NOTE(review): the kernels visible for ``F.grayscale_to_rgb`` are registered for
    # ``torch.Tensor``, ``tv_tensors.Image`` and PIL images only -- confirm how
    # ``tv_tensors.Video`` inputs are handled before relying on the "videos" wording.

    def __init__(self):
        super().__init__()

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # The transform takes no parameters, so ``params`` is not used.
        return self._call_kernel(F.grayscale_to_rgb, inpt)
class ColorJitter(Transform): class ColorJitter(Transform):
"""Randomly change the brightness, contrast, saturation and hue of an image or video. """Randomly change the brightness, contrast, saturation and hue of an image or video.
......
...@@ -63,6 +63,8 @@ from ._color import ( ...@@ -63,6 +63,8 @@ from ._color import (
equalize, equalize,
equalize_image, equalize_image,
equalize_video, equalize_video,
grayscale_to_rgb,
grayscale_to_rgb_image,
invert, invert,
invert_image, invert_image,
invert_video, invert_video,
......
...@@ -65,6 +65,32 @@ def _rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int ...@@ -65,6 +65,32 @@ def _rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int
return _FP.to_grayscale(image, num_output_channels=num_output_channels) return _FP.to_grayscale(image, num_output_channels=num_output_channels)
def grayscale_to_rgb(inpt: torch.Tensor) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RGB` for details."""
    # When scripting, skip the kernel lookup and call the tensor kernel directly.
    if torch.jit.is_scripting():
        return grayscale_to_rgb_image(inpt)

    _log_api_usage_once(grayscale_to_rgb)

    # Pick the kernel registered for the concrete type of the input.
    kernel = _get_kernel(grayscale_to_rgb, type(inpt))
    return kernel(inpt)
@_register_kernel_internal(grayscale_to_rgb, torch.Tensor)
@_register_kernel_internal(grayscale_to_rgb, tv_tensors.Image)
def grayscale_to_rgb_image(image: torch.Tensor) -> torch.Tensor:
    """Expand an image tensor to three channels.

    The channel count is read from ``image.shape[-3]``. An input that already has
    three or more channels is returned as is (the same object, not a copy).
    """
    if image.shape[-3] < 3:
        # The grayscale kernel can also emit three channels, so it is reused here
        # to add the missing ones.
        return _rgb_to_grayscale_image(image, num_output_channels=3, preserve_dtype=True)

    # Nothing to do: the image already has RGB channels.
    return image
@_register_kernel_internal(grayscale_to_rgb, PIL.Image.Image)
def grayscale_to_rgb_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    """Return ``image`` converted to PIL mode ``"RGB"``."""
    rgb_image = image.convert(mode="RGB")
    return rgb_image
def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor: def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
ratio = float(ratio) ratio = float(ratio)
fp = image1.is_floating_point() fp = image1.is_floating_point()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment