Unverified Commit 095cabb7 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

port image type conversion transforms to prototype API (#5640)



* port image type conversion transforms to prototype API

* implement proposal for image type conversion

* add deprecation warnings

* appease mypy
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent a7746efb
......@@ -330,6 +330,10 @@ def rotate_segmentation_mask():
and callable(kernel)
and any(feature_type in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label"})
and "pil" not in name
and name
not in {
"to_image_tensor",
}
],
)
def test_scriptable(kernel):
......
......@@ -22,4 +22,4 @@ from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColo
from ._misc import Identity, Normalize, ToDtype, Lambda
from ._type_conversion import DecodeImage, LabelToOneHot
from ._legacy import Grayscale, RandomGrayscale # usort: skip
from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip
from __future__ import annotations
import warnings
from typing import Any, Dict
from typing import Any, Dict, Optional
import numpy as np
import PIL.Image
from torchvision.prototype import features
from torchvision.prototype.features import ColorSpace
from torchvision.prototype.transforms import Transform
from torchvision.transforms import functional as _F
from typing_extensions import Literal
from ._meta import ConvertImageColorSpace
from ._transform import _RandomApplyTransform
from ._utils import is_simple_tensor
class ToTensor(Transform):
def __init__(self) -> None:
warnings.warn(
"The transform `ToTensor()` is deprecated and will be removed in a future release. "
"Instead, please use `transforms.ToImageTensor()`."
)
super().__init__()
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (PIL.Image.Image, np.ndarray)):
return _F.to_tensor(input)
else:
return input
class PILToTensor(Transform):
def __init__(self) -> None:
warnings.warn(
"The transform `PILToTensor()` is deprecated and will be removed in a future release. "
"Instead, please use `transforms.ToImageTensor()`."
)
super().__init__()
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, PIL.Image.Image):
return _F.pil_to_tensor(input)
else:
return input
class ToPILImage(Transform):
def __init__(self, mode: Optional[str] = None) -> None:
warnings.warn(
"The transform `ToPILImage()` is deprecated and will be removed in a future release. "
"Instead, please use `transforms.ToImagePIL()`."
)
super().__init__()
self.mode = mode
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if is_simple_tensor(input) or isinstance(input, (features.Image, np.ndarray)):
return _F.to_pil_image(input, mode=self.mode)
else:
return input
class Grayscale(Transform):
......
from typing import Any, Dict
import numpy as np
import PIL.Image
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F
from ._utils import is_simple_tensor
class DecodeImage(Transform):
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
......@@ -33,3 +37,28 @@ class LabelToOneHot(Transform):
return ""
return f"num_categories={self.num_categories}"
class ToImageTensor(Transform):
def __init__(self, *, copy: bool = False) -> None:
super().__init__()
self.copy = copy
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(input):
output = F.to_image_tensor(input, copy=self.copy)
return features.Image(output)
else:
return input
class ToImagePIL(Transform):
def __init__(self, *, copy: bool = False) -> None:
super().__init__()
self.copy = copy
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(input):
return F.to_image_pil(input, copy=self.copy)
else:
return input
......@@ -74,4 +74,10 @@ from ._geometry import (
ten_crop_image_pil,
)
from ._misc import normalize_image_tensor, gaussian_blur_image_tensor
from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot
from ._type_conversion import (
decode_image_with_pil,
decode_video_with_av,
label_to_one_hot,
to_image_tensor,
to_image_pil,
)
import unittest.mock
from typing import Dict, Any, Tuple
from typing import Dict, Any, Tuple, Union
import numpy as np
import PIL.Image
......@@ -7,6 +7,7 @@ import torch
from torch.nn.functional import one_hot
from torchvision.io.video import read_video
from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer
from torchvision.transforms import functional as _F
def decode_image_with_pil(encoded_image: torch.Tensor) -> torch.Tensor:
......@@ -23,3 +24,23 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor
def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tensor:
return one_hot(label, num_classes=num_categories) # type: ignore[no-any-return]
def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> torch.Tensor:
if isinstance(image, torch.Tensor):
if copy:
return image.clone()
else:
return image
return _F.to_tensor(image)
def to_image_pil(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> PIL.Image.Image:
if isinstance(image, PIL.Image.Image):
if copy:
return image.copy()
else:
return image
return _F.to_pil_image(to_image_tensor(image, copy=False))
......@@ -120,7 +120,7 @@ def _is_numpy_image(img: Any) -> bool:
return img.ndim in {2, 3}
def to_tensor(pic):
def to_tensor(pic) -> Tensor:
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
This function does not support torchscript.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment