Unverified commit d5091562 authored by Nicolas Hug, committed by GitHub

Undeprecate ToGrayScale transforms and functionals (#7122)

parent 60c78f28
@@ -149,6 +149,8 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(num_output_channels=3),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
# Use default tolerances of `torch.testing.assert_close`
closeness_kwargs=dict(rtol=None, atol=None),
),
ConsistencyConfig(
prototype_transforms.ConvertDtype,
@@ -271,6 +273,9 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(p=0),
ArgsKwargs(p=1),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
# Use default tolerances of `torch.testing.assert_close`
closeness_kwargs=dict(rtol=None, atol=None),
),
ConsistencyConfig(
prototype_transforms.RandomResizedCrop,
......
@@ -230,6 +230,9 @@ class Datapoint(torch.Tensor):
) -> Datapoint:
return self
def to_grayscale(self, num_output_channels: int = 1) -> Datapoint:
return self
def adjust_brightness(self, brightness_factor: float) -> Datapoint:
return self
......
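The base-class `to_grayscale` added above simply returns `self`: color operations are no-ops for datapoints without color channels (bounding boxes, masks, labels), and `Image` and `Video` override the method with a real kernel in the next two hunks. A minimal, hypothetical sketch of that dispatch pattern, using simplified stand-ins rather than the real prototype classes:

import torch

class Datapoint(torch.Tensor):  # simplified stand-in, not the prototype class
    def to_grayscale(self, num_output_channels: int = 1) -> "Datapoint":
        return self  # no-op: non-image data passes through color transforms

class BoundingBox(Datapoint):
    pass  # inherits the no-op

box = torch.tensor([0.0, 0.0, 10.0, 10.0]).as_subclass(BoundingBox)
assert box.to_grayscale() is box  # unchanged, as expected for non-image data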
@@ -169,6 +169,12 @@ class Image(Datapoint):
)
return Image.wrap_like(self, output)
def to_grayscale(self, num_output_channels: int = 1) -> Image:
output = self._F.rgb_to_grayscale_image_tensor(
self.as_subclass(torch.Tensor), num_output_channels=num_output_channels
)
return Image.wrap_like(self, output)
def adjust_brightness(self, brightness_factor: float) -> Image:
output = self._F.adjust_brightness_image_tensor(
self.as_subclass(torch.Tensor), brightness_factor=brightness_factor
......
@@ -173,6 +173,12 @@ class Video(Datapoint):
)
return Video.wrap_like(self, output)
def to_grayscale(self, num_output_channels: int = 1) -> Video:
output = self._F.rgb_to_grayscale_image_tensor(
self.as_subclass(torch.Tensor), num_output_channels=num_output_channels
)
return Video.wrap_like(self, output)
def adjust_brightness(self, brightness_factor: float) -> Video:
output = self._F.adjust_brightness_video(self.as_subclass(torch.Tensor), brightness_factor=brightness_factor)
return Video.wrap_like(self, output)
......
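Note that `Video.to_grayscale` above reuses the image kernel rather than a dedicated video kernel. That works because the kernel only touches `dim=-3` (the channel dimension), so any leading batch/time dimensions broadcast through untouched. An illustrative check with the kernel's math reduced to its essentials (the luma weights appear in the `_color.py` hunk later in this diff; dtype handling omitted here):

import torch

def grayscale_kernel(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
    # Same weighted sum as the kernel added in this commit, minus dtype handling.
    r, g, b = image.unbind(dim=-3)
    l_img = (0.2989 * r + 0.587 * g + 0.114 * b).unsqueeze(dim=-3)
    return l_img.expand(image.shape) if num_output_channels == 3 else l_img

video = torch.rand(8, 3, 32, 32)  # (T, C, H, W)
print(grayscale_kernel(video).shape)  # torch.Size([8, 1, 32, 32]) -- T is untouched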
@@ -9,9 +9,11 @@ from ._augment import RandomCutmix, RandomErasing, RandomMixup, SimpleCopyPaste
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._color import (
ColorJitter,
Grayscale,
RandomAdjustSharpness,
RandomAutocontrast,
RandomEqualize,
RandomGrayscale,
RandomInvert,
RandomPhotometricDistort,
RandomPosterize,
@@ -54,4 +56,4 @@ from ._misc import (
from ._temporal import UniformTemporalSubsample
from ._type_conversion import LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage
from ._deprecated import Grayscale, RandomGrayscale, ToTensor # usort: skip
from ._deprecated import ToTensor # usort: skip
@@ -11,6 +11,41 @@ from ._transform import _RandomApplyTransform
from .utils import is_simple_tensor, query_chw
class Grayscale(Transform):
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(self, num_output_channels: int = 1):
super().__init__()
self.num_output_channels = num_output_channels
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
class RandomGrayscale(_RandomApplyTransform):
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(self, p: float = 0.1) -> None:
super().__init__(p=p)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
num_input_channels, *_ = query_chw(flat_inputs)
return dict(num_input_channels=num_input_channels)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
class ColorJitter(Transform):
def __init__(
self,
......
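With the deprecation dropped, the reinstated `Grayscale` and `RandomGrayscale` above behave like regular transforms again. A short usage sketch, assuming the prototype namespace as it stood at this commit:

import torch
from torchvision.prototype import transforms

gray = transforms.Grayscale(num_output_channels=3)  # deterministic conversion
maybe_gray = transforms.RandomGrayscale(p=0.1)      # fires with probability p

img = torch.rand(3, 224, 224)
out = gray(img)        # (3, 224, 224): single gray channel expanded back to 3
out = maybe_gray(img)  # when it fires, keeps the input channel count (see _get_params)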
import warnings
from typing import Any, Dict, List, Union
from typing import Any, Dict, Union
import numpy as np
import PIL.Image
import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import Transform
from torchvision.transforms import functional as _F
from typing_extensions import Literal
from ._transform import _RandomApplyTransform
from .utils import is_simple_tensor, query_chw
class ToTensor(Transform):
@@ -26,78 +21,3 @@ class ToTensor(Transform):
def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, Any]) -> torch.Tensor:
return _F.to_tensor(inpt)
# TODO: in other PR (?) undeprecate those and make them use _rgb_to_gray?
class Grayscale(Transform):
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
deprecation_msg = (
f"The transform `Grayscale(num_output_channels={num_output_channels})` "
f"is deprecated and will be removed in a future release."
)
if num_output_channels == 1:
replacement_msg = (
"transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)"
)
else:
replacement_msg = (
"transforms.Compose(\n"
" transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY),\n"
" transforms.ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB),\n"
")"
)
warnings.warn(f"{deprecation_msg} Instead, please use\n\n{replacement_msg}")
super().__init__()
self.num_output_channels = num_output_channels
def _transform(
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
) -> Union[datapoints.ImageType, datapoints.VideoType]:
output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = inpt.wrap_like(inpt, output) # type: ignore[arg-type]
return output
class RandomGrayscale(_RandomApplyTransform):
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(self, p: float = 0.1) -> None:
warnings.warn(
"The transform `RandomGrayscale(p=...)` is deprecated and will be removed in a future release. "
"Instead, please use\n\n"
"transforms.RandomApply(\n"
" transforms.Compose(\n"
" transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY),\n"
" transforms.ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB),\n"
" )\n"
" p=...,\n"
")"
)
super().__init__(p=p)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
num_input_channels, *_ = query_chw(flat_inputs)
return dict(num_input_channels=num_input_channels)
def _transform(
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
) -> Union[datapoints.ImageType, datapoints.VideoType]:
output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = inpt.wrap_like(inpt, output) # type: ignore[arg-type]
return output
@@ -71,6 +71,9 @@ from ._color import (
posterize_image_pil,
posterize_image_tensor,
posterize_video,
rgb_to_grayscale,
rgb_to_grayscale_image_pil,
rgb_to_grayscale_image_tensor,
solarize,
solarize_image_pil,
solarize_image_tensor,
@@ -167,4 +170,4 @@ from ._misc import (
from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video
from ._type_conversion import pil_to_tensor, to_image_pil, to_image_tensor, to_pil_image
from ._deprecated import get_image_size, rgb_to_grayscale, to_grayscale, to_tensor # usort: skip
from ._deprecated import get_image_size, to_grayscale, to_tensor # usort: skip
from typing import Union
import PIL.Image
import torch
from torch.nn.functional import conv2d
@@ -7,10 +9,53 @@ from torchvision.transforms.functional_tensor import _max_value
from torchvision.utils import _log_api_usage_once
from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor
from ._meta import _num_value_bits, convert_dtype_image_tensor
from ._utils import is_simple_tensor
def _rgb_to_grayscale_image_tensor(
image: torch.Tensor, num_output_channels: int = 1, preserve_dtype: bool = True
) -> torch.Tensor:
if image.shape[-3] == 1:
return image.clone()
r, g, b = image.unbind(dim=-3)
l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
l_img = l_img.unsqueeze(dim=-3)
if preserve_dtype:
l_img = l_img.to(image.dtype)
if num_output_channels == 3:
l_img = l_img.expand(image.shape)
return l_img
def rgb_to_grayscale_image_tensor(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
return _rgb_to_grayscale_image_tensor(image, num_output_channels=num_output_channels, preserve_dtype=True)
rgb_to_grayscale_image_pil = _FP.to_grayscale
def rgb_to_grayscale(
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1
) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
if not torch.jit.is_scripting():
_log_api_usage_once(rgb_to_grayscale)
if num_output_channels not in (1, 3):
raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.to_grayscale(num_output_channels=num_output_channels)
elif isinstance(inpt, PIL.Image.Image):
return rgb_to_grayscale_image_pil(inpt, num_output_channels=num_output_channels)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
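The dispatcher above follows the usual three-way pattern of the prototype functionals: plain tensors go straight to the tensor kernel, datapoints dispatch through their `to_grayscale` method, and PIL images go to the PIL kernel. A quick sanity check of the tensor path against the ITU-R BT.601 luma weights used by the kernel (import path assumed from the prototype layout at this commit):

import torch
from torchvision.prototype.transforms import functional as F

img = torch.rand(3, 4, 4)
out = F.rgb_to_grayscale(img, num_output_channels=1)

# The kernel computes a BT.601 weighted sum over the channel dimension.
r, g, b = img.unbind(dim=-3)
expected = (0.2989 * r + 0.587 * g + 0.114 * b).unsqueeze(dim=-3)
torch.testing.assert_close(out, expected)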
def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
ratio = float(ratio)
fp = image1.is_floating_point()
@@ -68,7 +113,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float
if c == 1: # Match PIL behaviour
return image
grayscale_image = _rgb_to_gray(image, cast=False)
grayscale_image = _rgb_to_grayscale_image_tensor(image, num_output_channels=1, preserve_dtype=False)
if not image.is_floating_point():
grayscale_image = grayscale_image.floor_()
@@ -110,7 +155,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
fp = image.is_floating_point()
if c == 3:
grayscale_image = _rgb_to_gray(image, cast=False)
grayscale_image = _rgb_to_grayscale_image_tensor(image, num_output_channels=1, preserve_dtype=False)
if not fp:
grayscale_image = grayscale_image.floor_()
else:
......
@@ -7,8 +7,6 @@ import torch
from torchvision.prototype import datapoints
from torchvision.transforms import functional as _F
from ._utils import is_simple_tensor
@torch.jit.unused
def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
@@ -24,33 +22,6 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
return _F.to_grayscale(inpt, num_output_channels=num_output_channels)
def rgb_to_grayscale(
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1
) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
old_color_space = None # TODO: remove when un-deprecating
if not (torch.jit.is_scripting() or is_simple_tensor(inpt)) and isinstance(
inpt, (datapoints.Image, datapoints.Video)
):
inpt = inpt.as_subclass(torch.Tensor)
call = ", num_output_channels=3" if num_output_channels == 3 else ""
replacement = (
f"convert_color_space(..., color_space=datapoints.ColorSpace.GRAY"
f"{f', old_color_space=datapoints.ColorSpace.{old_color_space}' if old_color_space is not None else ''})"
)
if num_output_channels == 3:
replacement = (
f"convert_color_space({replacement}, color_space=datapoints.ColorSpace.RGB"
f"{f', old_color_space=datapoints.ColorSpace.GRAY' if old_color_space is not None else ''})"
)
warnings.warn(
f"The function `rgb_to_grayscale(...{call})` is deprecated in will be removed in a future release. "
f"Instead, please use `{replacement}`.",
)
return _F.rgb_to_grayscale(inpt, num_output_channels=num_output_channels)
@torch.jit.unused
def to_tensor(inpt: Any) -> torch.Tensor:
warnings.warn(
......
@@ -225,15 +225,6 @@ def clamp_bounding_box(
return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True)
def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor:
r, g, b = image.unbind(dim=-3)
l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
if cast:
l_img = l_img.to(image.dtype)
l_img = l_img.unsqueeze(dim=-3)
return l_img
def _num_value_bits(dtype: torch.dtype) -> int:
if dtype == torch.uint8:
return 8
......