Unverified Commit 7245dc9e authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

More cleanup for prototype transforms (#6500)

* add aliases for hflip and vflip

* reduce imports from torchvision.transforms in torchvision.prototype.transforms

* add aliases for to_pil_image and pil_to_tensor

* deprecate to_tensor

* add some FIXME cleanup comments

* address reviews

* add dimension getters

* undeprecate PILToTensor and ToPILImage

* address review

* fix test
parent 7de63171
......@@ -1071,10 +1071,9 @@ class TestToPILImage:
[torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch("torchvision.transforms.functional.to_pil_image")
fn = mocker.patch("torchvision.prototype.transforms.functional.to_image_pil")
inpt = mocker.MagicMock(spec=inpt_type)
with pytest.warns(UserWarning, match="deprecated and will be removed"):
transform = transforms.ToPILImage()
transform(inpt)
if inpt_type in (PIL.Image.Image, features.BoundingBox, str, int):
......
......@@ -674,6 +674,8 @@ def erase_image_tensor():
and name
not in {
"to_image_tensor",
"get_image_num_channels",
"get_image_size",
}
],
)
......
from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort: skip
from . import functional # usort: skip
from ._transform import Transform # usort: skip
from ._augment import RandomCutmix, RandomErasing, RandomMixup, SimpleCopyPaste
from ._auto_augment import AugMix, AutoAugment, AutoAugmentPolicy, RandAugment, TrivialAugmentWide
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._color import (
ColorJitter,
RandomAdjustSharpness,
......@@ -37,6 +39,6 @@ from ._geometry import (
)
from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype
from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, RemoveSmallBoundingBoxes, ToDtype
from ._type_conversion import DecodeImage, LabelToOneHot, ToImagePIL, ToImageTensor
from ._type_conversion import DecodeImage, LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage
from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip
from ._deprecated import Grayscale, RandomGrayscale, ToTensor # usort: skip
......@@ -8,9 +8,7 @@ import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.ops import masks_to_boxes
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor
from torchvision.prototype.transforms import functional as F, InterpolationMode
from ._transform import _RandomApplyTransform
from ._utils import has_any, query_chw
......@@ -279,7 +277,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
if isinstance(obj, features.Image) or features.is_simple_tensor(obj):
images.append(obj)
elif isinstance(obj, PIL.Image.Image):
images.append(pil_to_tensor(obj))
images.append(F.to_image_tensor(obj))
elif isinstance(obj, features.BoundingBox):
bboxes.append(obj)
elif isinstance(obj, features.SegmentationMask):
......
......@@ -7,11 +7,10 @@ import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.autoaugment import AutoAugmentPolicy
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._meta import get_chw
from ._utils import _isinstance, get_chw
from ._utils import _isinstance
K = TypeVar("K")
V = TypeVar("V")
......@@ -473,7 +472,7 @@ class AugMix(_AutoAugmentBase):
if isinstance(orig_image, torch.Tensor):
image = orig_image
else: # isinstance(inpt, PIL.Image.Image):
image = pil_to_tensor(orig_image)
image = F.to_image_tensor(orig_image)
augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
......@@ -516,6 +515,6 @@ class AugMix(_AutoAugmentBase):
if isinstance(orig_image, features.Image):
mix = features.Image.new_like(orig_image, mix)
elif isinstance(orig_image, PIL.Image.Image):
mix = to_pil_image(mix)
mix = F.to_image_pil(mix)
return self._put_into_sample(sample, id, mix)
......@@ -5,7 +5,6 @@ import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms import functional as _F
from ._transform import _RandomApplyTransform
from ._utils import query_chw
......@@ -85,6 +84,8 @@ class ColorJitter(Transform):
class RandomPhotometricDistort(Transform):
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
def __init__(
self,
contrast: Tuple[float, float] = (0.5, 1.5),
......@@ -112,19 +113,15 @@ class RandomPhotometricDistort(Transform):
)
def _permute_channels(self, inpt: Any, *, permutation: torch.Tensor) -> Any:
    """Reorder the channel dimension of an image according to ``permutation``.

    Args:
        inpt: Candidate input. Anything that is not a ``features.Image``, a
            ``PIL.Image.Image`` or a simple tensor is returned unchanged.
        permutation: Index tensor used to index the third-to-last dimension.

    Returns:
        The permuted image in the same kind of container as the input: a PIL
        image for PIL input, a ``features.Image`` (color space ``OTHER``) for
        ``features.Image`` input, and the plain indexing result otherwise.
    """
    if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or features.is_simple_tensor(inpt)):
        return inpt

    # Record the input kind *before* converting. Rebinding ``inpt`` to the
    # converted tensor and testing it afterwards would make the PIL branch
    # unreachable, so PIL inputs would never be converted back.
    is_pil = isinstance(inpt, PIL.Image.Image)
    image = F.to_image_tensor(inpt) if is_pil else inpt

    output = image[..., permutation, :, :]

    if is_pil:
        output = F.to_image_pil(output)
    elif isinstance(inpt, features.Image):
        # The channel order no longer matches the original color space.
        output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.OTHER)

    return output
......
import warnings
from typing import Any, Dict, Optional
from typing import Any, Dict
import numpy as np
import PIL.Image
......@@ -20,7 +20,7 @@ class ToTensor(Transform):
def __init__(self) -> None:
warnings.warn(
"The transform `ToTensor()` is deprecated and will be removed in a future release. "
"Instead, please use `transforms.ToImageTensor()`."
"Instead, please use `transforms.Compose([transforms.ToImageTensor(), transforms.ConvertImageDtype()])`."
)
super().__init__()
......@@ -28,35 +28,6 @@ class ToTensor(Transform):
return _F.to_tensor(inpt)
class PILToTensor(Transform):
    """Deprecated transform that converts PIL images with ``_F.pil_to_tensor``.

    Only ``PIL.Image.Image`` inputs are listed in ``_transformed_types``.
    Constructing the transform emits a deprecation warning that points to
    ``transforms.ToImageTensor()``.
    """

    _transformed_types = (PIL.Image.Image,)

    def __init__(self) -> None:
        deprecation_msg = (
            "The transform `PILToTensor()` is deprecated and will be removed in a future release. "
            "Instead, please use `transforms.ToImageTensor()`."
        )
        warnings.warn(deprecation_msg)
        super().__init__()

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
        """Delegate to ``_F.pil_to_tensor``; ``params`` is not used."""
        tensor = _F.pil_to_tensor(inpt)
        return tensor
class ToPILImage(Transform):
    """Deprecated transform that converts inputs with ``_F.to_pil_image``.

    ``_transformed_types`` lists simple tensors, ``features.Image`` and
    ``np.ndarray``. Constructing the transform emits a deprecation warning
    that points to ``transforms.ToImagePIL()``.
    """

    _transformed_types = (features.is_simple_tensor, features.Image, np.ndarray)

    def __init__(self, mode: Optional[str] = None) -> None:
        deprecation_msg = (
            "The transform `ToPILImage()` is deprecated and will be removed in a future release. "
            "Instead, please use `transforms.ToImagePIL()`."
        )
        warnings.warn(deprecation_msg)
        super().__init__()
        # Forwarded verbatim to ``_F.to_pil_image`` on every call.
        self.mode = mode

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image.Image:
        """Delegate to ``_F.to_pil_image`` with the stored ``mode``; ``params`` is not used."""
        pil_image = _F.to_pil_image(inpt, mode=self.mode)
        return pil_image
class Grayscale(Transform):
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
......
......@@ -7,15 +7,21 @@ import PIL.Image
import torch
from torchvision.ops.boxes import box_iou
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms.functional_tensor import _parse_pad_padding
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size
from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
from typing_extensions import Literal
from ._transform import _RandomApplyTransform
from ._utils import has_all, has_any, query_bounding_box, query_chw
from ._utils import (
_check_sequence_input,
_parse_pad_padding,
_setup_angle,
_setup_size,
has_all,
has_any,
query_bounding_box,
query_chw,
)
class RandomHorizontalFlip(_RandomApplyTransform):
......
......@@ -5,7 +5,6 @@ import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.functional import convert_image_dtype
class ConvertBoundingBoxFormat(Transform):
......@@ -30,7 +29,7 @@ class ConvertImageDtype(Transform):
self.dtype = dtype
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
output = convert_image_dtype(inpt, dtype=self.dtype)
output = F.convert_image_dtype(inpt, dtype=self.dtype)
return output if features.is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype)
......
......@@ -7,8 +7,8 @@ import torch
from torchvision.ops import remove_small_boxes
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.prototype.transforms._utils import has_any, query_bounding_box
from torchvision.transforms.transforms import _setup_size
from ._utils import _setup_size, has_any, query_bounding_box
class Identity(Transform):
......
......@@ -52,3 +52,9 @@ class ToImagePIL(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image.Image:
    """Convert ``inpt`` via ``F.to_image_pil`` using ``self.mode``; ``params`` is not used."""
    pil_image = F.to_image_pil(inpt, mode=self.mode)
    return pil_image
# We changed the names to align them with the new naming scheme. Still, `PILToTensor` and `ToPILImage` are
# prevalent and well understood. Thus, we just alias them without deprecating the old names.
PILToTensor = ToImageTensor
ToPILImage = ToImagePIL
from typing import Any, Callable, Tuple, Type, Union
import PIL.Image
import torch
from torch.utils._pytree import tree_flatten
from torchvision._utils import sequence_to_str
from torchvision.prototype import features
from .functional._meta import get_dimensions_image_pil, get_dimensions_image_tensor
from torchvision.prototype.transforms.functional._meta import get_chw
from torchvision.transforms.functional_tensor import _parse_pad_padding # noqa: F401
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
def query_bounding_box(sample: Any) -> features.BoundingBox:
......@@ -19,19 +20,6 @@ def query_bounding_box(sample: Any) -> features.BoundingBox:
return bounding_boxes.pop()
def get_chw(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
    """Return ``(channels, height, width)`` for a supported image object.

    ``features.Image`` inputs are answered from their ``num_channels`` and
    ``image_size`` attributes. Simple tensors and PIL images are dispatched to
    ``get_dimensions_image_tensor`` and ``get_dimensions_image_pil``.

    Raises:
        TypeError: If ``image`` is none of the supported kinds.
    """
    if isinstance(image, features.Image):
        channels = image.num_channels
        height, width = image.image_size
        return channels, height, width

    # Pick the kernel matching the container; both yield three values.
    if features.is_simple_tensor(image):
        query_dimensions = get_dimensions_image_tensor
    elif isinstance(image, PIL.Image.Image):
        query_dimensions = get_dimensions_image_pil
    else:
        raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}")

    channels, height, width = query_dimensions(image)
    return channels, height, width
def query_chw(sample: Any) -> Tuple[int, int, int]:
flat_sample, _ = tree_flatten(sample)
chws = {
......
......@@ -5,6 +5,9 @@ from ._meta import (
convert_color_space_image_tensor,
convert_color_space_image_pil,
convert_color_space,
get_dimensions,
get_image_num_channels,
get_image_size,
) # usort: skip
from ._augment import erase, erase_image_pil, erase_image_tensor
......@@ -68,6 +71,7 @@ from ._geometry import (
five_crop,
five_crop_image_pil,
five_crop_image_tensor,
hflip,
horizontal_flip,
horizontal_flip_bounding_box,
horizontal_flip_image_pil,
......@@ -106,8 +110,17 @@ from ._geometry import (
vertical_flip_image_pil,
vertical_flip_image_tensor,
vertical_flip_segmentation_mask,
vflip,
)
from ._misc import gaussian_blur, gaussian_blur_image_pil, gaussian_blur_image_tensor, normalize, normalize_image_tensor
from ._type_conversion import decode_image_with_pil, decode_video_with_av, to_image_pil, to_image_tensor
from ._type_conversion import (
convert_image_dtype,
decode_image_with_pil,
decode_video_with_av,
pil_to_tensor,
to_image_pil,
to_image_tensor,
to_pil_image,
)
from ._deprecated import rgb_to_grayscale, to_grayscale # usort: skip
......@@ -2,6 +2,7 @@ import warnings
from typing import Any
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.transforms import functional as _F
......@@ -39,3 +40,11 @@ def rgb_to_grayscale(inpt: Any, num_output_channels: int = 1) -> Any:
)
return _F.rgb_to_grayscale(inpt, num_output_channels=num_output_channels)
def to_tensor(inpt: Any) -> torch.Tensor:
    """Deprecated wrapper around ``_F.to_tensor``.

    Emits a warning that names ``to_image_tensor(...)`` followed by
    ``convert_image_dtype(...)`` as the replacement, then returns the result of
    ``_F.to_tensor(inpt)``.
    """
    deprecation_msg = (
        "The function `to_tensor(...)` is deprecated and will be removed in a future release. "
        "Instead, please use `to_image_tensor(...)` followed by `convert_image_dtype(...)`."
    )
    warnings.warn(deprecation_msg)
    return _F.to_tensor(inpt)
......@@ -89,6 +89,12 @@ def vertical_flip(inpt: DType) -> DType:
return vertical_flip_image_tensor(inpt)
# We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
# prevalent and well understood. Thus, we just alias them without deprecating the old names.
hflip = horizontal_flip
vflip = vertical_flip
def resize_image_tensor(
image: torch.Tensor,
size: List[int],
......
from typing import Any, Optional, Tuple
from typing import Any, List, Optional, Tuple, Union
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.features import BoundingBoxFormat, ColorSpace, Image
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
......@@ -9,6 +10,40 @@ get_dimensions_image_tensor = _FT.get_dimensions
get_dimensions_image_pil = _FP.get_dimensions
def get_chw(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
    """Return ``(channels, height, width)`` for a supported image object.

    ``features.Image`` inputs are answered from their ``num_channels`` and
    ``image_size`` attributes. Simple tensors and PIL images are dispatched to
    ``get_dimensions_image_tensor`` and ``get_dimensions_image_pil``.

    Raises:
        TypeError: If ``image`` is none of the supported kinds.
    """
    if isinstance(image, features.Image):
        channels = image.num_channels
        height, width = image.image_size
        return channels, height, width

    # Pick the kernel matching the container; both yield three values.
    if features.is_simple_tensor(image):
        query_dimensions = get_dimensions_image_tensor
    elif isinstance(image, PIL.Image.Image):
        query_dimensions = get_dimensions_image_pil
    else:
        raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}")

    channels, height, width = query_dimensions(image)
    return channels, height, width
# The three functions below are here for BC. Whether we want to have two different kernels and how they and the
# compound version should be named is still under discussion: https://github.com/pytorch/vision/issues/6491
# Given that these kernels should also support boxes, masks, and videos, it is unlikely that their name will stay.
# They will either be deprecated or simply aliased to the new kernels if we have reached consensus about the issue
# detailed above.
def get_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> List[int]:
    """Return the values produced by ``get_chw(image)`` as a list instead of a tuple."""
    return [*get_chw(image)]
def get_image_num_channels(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> int:
    """Return the first value (the channel count) reported by ``get_chw(image)``."""
    return get_chw(image)[0]
def get_image_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> List[int]:
    """Return everything after the channel count from ``get_chw(image)`` as a list."""
    return list(get_chw(image)[1:])
def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
xyxy = xywh.clone()
xyxy[..., 2:] += xyxy[..., :2]
......
......@@ -31,3 +31,10 @@ def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) ->
to_image_pil = _F.to_pil_image
# We changed the names to align them with the new naming scheme. Still, `to_pil_image` and `pil_to_tensor` are
# prevalent and well understood. Thus, we just alias them without deprecating the old names.
to_pil_image = to_image_pil
pil_to_tensor = to_image_tensor

# Not a rename: `convert_image_dtype` is bound directly to the `_F` implementation so that it can be
# reached through this module as well.
convert_image_dtype = _F.convert_image_dtype
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment