Unverified Commit f725901d authored by Vasilis Vryniotis, committed by GitHub

[prototype] Align and Clean up transform types (#6627)



* Align and Clean up transform types

* Move type definitions to `_utils.py`

* Fixing error message on tests

* Apply code review suggestions
Co-authored-by: vfdev <vfdev.5@gmail.com>

* Centralizing types and switching to always getting dicts.

* Fixing linter

* Refactoring typing definitions.

* Remove relative imports.

* Reuse type.

* Temporarily remove the TorchData tests.

* Restore the TorchData tests.
Co-authored-by: vfdev <vfdev.5@gmail.com>
parent 6b1646ca
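
For orientation before the hunks: the PR swaps ad-hoc `Union[...]` annotations for type aliases defined next to the feature classes, each in an eager flavor plus a `...JIT` twin. The comments deleted below spell out why the twins exist: `torch.jit.script` cannot compile a `Union` containing `PIL.Image.Image` or a `_Feature` subclass, so every JIT alias collapses to plain `torch.Tensor`. A minimal, self-contained sketch of that split; the `invert` helper is hypothetical and not code from this PR:

    from typing import Union

    import PIL.Image
    import torch

    ImageType = Union[torch.Tensor, PIL.Image.Image]  # what eager callers may pass
    ImageTypeJIT = torch.Tensor  # the only annotation torch.jit.script will compile

    def invert(inpt: ImageTypeJIT) -> ImageTypeJIT:
        # Under scripting only tensors can reach this function; the eager-only
        # PIL branch is elided to keep the sketch scriptable.
        if isinstance(inpt, torch.Tensor):
            return 255 - inpt
        raise TypeError("PIL branch elided in this sketch")

    scripted = torch.jit.script(invert)  # compiles, because the annotation is a bare Tensor
    assert int(scripted(torch.zeros(1, dtype=torch.uint8))) == 255
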
......@@ -132,7 +132,7 @@ class TestSmoke:
transform(input_copy)
# Check if we raise an error if sample contains bbox or mask or label
err_msg = "does not support bounding boxes, masks and plain labels"
err_msg = "does not support PIL images, bounding boxes, masks and plain labels"
input_copy = dict(input)
for unsup_data in [
make_label(),
......
from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._encoded import EncodedData, EncodedImage, EncodedVideo
from ._feature import _Feature, DType, is_simple_tensor
from ._image import ColorSpace, Image, ImageType
from ._feature import _Feature, FillType, FillTypeJIT, InputType, InputTypeJIT, is_simple_tensor
from ._image import (
ColorSpace,
Image,
ImageType,
ImageTypeJIT,
LegacyImageType,
LegacyImageTypeJIT,
TensorImageType,
TensorImageTypeJIT,
)
from ._label import Label, OneHotLabel
from ._mask import Mask
......@@ -6,7 +6,7 @@ import torch
from torchvision._utils import StrEnum
from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms
from ._feature import _Feature
from ._feature import _Feature, FillTypeJIT
class BoundingBoxFormat(StrEnum):
......@@ -115,7 +115,7 @@ class BoundingBox(_Feature):
def pad(
self,
padding: Union[int, Sequence[int]],
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
padding_mode: str = "constant",
) -> BoundingBox:
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
......@@ -137,7 +137,7 @@ class BoundingBox(_Feature):
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> BoundingBox:
output = self._F.rotate_bounding_box(
......@@ -165,7 +165,7 @@ class BoundingBox(_Feature):
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> BoundingBox:
output = self._F.affine_bounding_box(
......@@ -184,7 +184,7 @@ class BoundingBox(_Feature):
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
) -> BoundingBox:
output = self._F.perspective_bounding_box(self, self.format, perspective_coeffs)
return BoundingBox.new_like(self, output, dtype=output.dtype)
......@@ -193,7 +193,7 @@ class BoundingBox(_Feature):
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
) -> BoundingBox:
output = self._F.elastic_bounding_box(self, self.format, displacement)
return BoundingBox.new_like(self, output, dtype=output.dtype)
......@@ -3,16 +3,14 @@ from __future__ import annotations
from types import ModuleType
from typing import Any, Callable, cast, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union
import PIL.Image
import torch
from torch._C import _TensorBase, DisableTorchFunction
from torchvision.transforms import InterpolationMode
F = TypeVar("F", bound="_Feature")
# Due to torch.jit.script limitation we keep DType as torch.Tensor
# instead of Union[torch.Tensor, PIL.Image.Image, features._Feature]
DType = torch.Tensor
FillType = Union[int, float, Sequence[int], Sequence[float], None]
FillTypeJIT = Union[int, float, List[float], None]
def is_simple_tensor(inpt: Any) -> bool:
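
The `FillType`/`FillTypeJIT` pair above encodes a user-facing versus scriptable split: transforms accept any `Sequence`, but TorchScript only understands concrete `List`s, so the low-level kernels take `FillTypeJIT` and `_convert_fill_arg` (further down in this diff) narrows one into the other. A condensed, hypothetical `convert_fill` mirroring that narrowing; like the real helper it keeps `None` distinct from `0` (https://github.com/pytorch/vision/issues/6517):

    from typing import List, Sequence, Union

    FillType = Union[int, float, Sequence[int], Sequence[float], None]
    FillTypeJIT = Union[int, float, List[float], None]

    def convert_fill(fill: FillType) -> FillTypeJIT:
        # Scalars and None pass through untouched; any sequence is materialized
        # as a concrete List[float], the only container form TorchScript takes.
        if fill is None or isinstance(fill, (int, float)):
            return fill
        return [float(v) for v in fill]

    assert convert_fill((12, 12, 12)) == [12.0, 12.0, 12.0]
    assert convert_fill(None) is None
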
......@@ -154,7 +152,7 @@ class _Feature(torch.Tensor):
def pad(
self,
padding: Union[int, List[int]],
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
padding_mode: str = "constant",
) -> _Feature:
return self
......@@ -164,7 +162,7 @@ class _Feature(torch.Tensor):
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> _Feature:
return self
......@@ -176,7 +174,7 @@ class _Feature(torch.Tensor):
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> _Feature:
return self
......@@ -185,7 +183,7 @@ class _Feature(torch.Tensor):
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
) -> _Feature:
return self
......@@ -193,7 +191,7 @@ class _Feature(torch.Tensor):
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
) -> _Feature:
return self
......@@ -232,3 +230,7 @@ class _Feature(torch.Tensor):
def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> _Feature:
return self
InputType = Union[torch.Tensor, PIL.Image.Image, _Feature]
InputTypeJIT = torch.Tensor
......@@ -3,18 +3,14 @@ from __future__ import annotations
import warnings
from typing import Any, cast, List, Optional, Tuple, Union
import PIL.Image
import torch
from torchvision._utils import StrEnum
from torchvision.transforms.functional import InterpolationMode, to_pil_image
from torchvision.utils import draw_bounding_boxes, make_grid
from ._bounding_box import BoundingBox
from ._feature import _Feature
# Due to torch.jit.script limitation we keep ImageType as torch.Tensor
# instead of Union[torch.Tensor, PIL.Image.Image, features.Image]
ImageType = torch.Tensor
from ._feature import _Feature, FillTypeJIT
class ColorSpace(StrEnum):
......@@ -181,7 +177,7 @@ class Image(_Feature):
def pad(
self,
padding: Union[int, List[int]],
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
padding_mode: str = "constant",
) -> Image:
output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
......@@ -192,7 +188,7 @@ class Image(_Feature):
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Image:
output = self._F._geometry.rotate_image_tensor(
......@@ -207,7 +203,7 @@ class Image(_Feature):
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Image:
output = self._F._geometry.affine_image_tensor(
......@@ -226,7 +222,7 @@ class Image(_Feature):
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
) -> Image:
output = self._F._geometry.perspective_image_tensor(
self, perspective_coeffs, interpolation=interpolation, fill=fill
......@@ -237,7 +233,7 @@ class Image(_Feature):
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
) -> Image:
output = self._F._geometry.elastic_image_tensor(self, displacement, interpolation=interpolation, fill=fill)
return Image.new_like(self, output)
......@@ -289,3 +285,11 @@ class Image(_Feature):
def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image:
output = self._F.gaussian_blur_image_tensor(self, kernel_size=kernel_size, sigma=sigma)
return Image.new_like(self, output)
ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
ImageTypeJIT = torch.Tensor
LegacyImageType = Union[torch.Tensor, PIL.Image.Image]
LegacyImageTypeJIT = torch.Tensor
TensorImageType = Union[torch.Tensor, Image]
TensorImageTypeJIT = torch.Tensor
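
The three alias families above differ only in what they admit besides plain tensors: `ImageType` is the full dispatch surface (tensor, PIL and feature kernels), `LegacyImageType` serves wrappers around the feature-unaware stable API, and `TensorImageType` serves tensor-only ops such as `normalize`. A throwaway sketch of the relationships, with a stand-in `Image` class since the real one lives in this package:

    from typing import Union, get_args

    import PIL.Image
    import torch

    class Image(torch.Tensor):  # stand-in for features.Image
        pass

    ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
    LegacyImageType = Union[torch.Tensor, PIL.Image.Image]
    TensorImageType = Union[torch.Tensor, Image]

    pil = PIL.Image.new("RGB", (4, 4))
    assert isinstance(pil, get_args(ImageType))
    assert not isinstance(pil, get_args(TensorImageType))  # normalize() rejects PIL input
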
......@@ -5,7 +5,7 @@ from typing import List, Optional, Union
import torch
from torchvision.transforms import InterpolationMode
from ._feature import _Feature
from ._feature import _Feature, FillTypeJIT
class Mask(_Feature):
......@@ -51,7 +51,7 @@ class Mask(_Feature):
def pad(
self,
padding: Union[int, List[int]],
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
padding_mode: str = "constant",
) -> Mask:
output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill)
......@@ -62,7 +62,7 @@ class Mask(_Feature):
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Mask:
output = self._F.rotate_mask(self, angle, expand=expand, center=center, fill=fill)
......@@ -75,7 +75,7 @@ class Mask(_Feature):
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Mask:
output = self._F.affine_mask(
......@@ -93,7 +93,7 @@ class Mask(_Feature):
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
) -> Mask:
output = self._F.perspective_mask(self, perspective_coeffs, fill=fill)
return Mask.new_like(self, output)
......@@ -102,7 +102,7 @@ class Mask(_Feature):
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, List[float]]] = None,
fill: FillTypeJIT = None,
) -> Mask:
output = self._F.elastic_mask(self, displacement, fill=fill)
return Mask.new_like(self, output, dtype=output.dtype)
import math
import numbers
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, cast, Dict, List, Optional, Tuple
import PIL.Image
import torch
......@@ -92,9 +92,7 @@ class RandomErasing(_RandomApplyTransform):
return dict(i=i, j=j, h=h, w=w, v=v)
def _transform(
self, inpt: Union[torch.Tensor, features.Image, PIL.Image.Image], params: Dict[str, Any]
) -> Union[torch.Tensor, features.Image, PIL.Image.Image]:
def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
if params["v"] is not None:
inpt = F.erase(inpt, **params, inplace=self.inplace)
......@@ -110,8 +108,10 @@ class _BaseMixupCutmix(_RandomApplyTransform):
def forward(self, *inputs: Any) -> Any:
if not (has_any(inputs, features.Image, features.is_simple_tensor) and has_any(inputs, features.OneHotLabel)):
raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.")
if has_any(inputs, features.BoundingBox, features.Mask, features.Label):
raise TypeError(f"{type(self).__name__}() does not support bounding boxes, masks and plain labels.")
if has_any(inputs, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label):
raise TypeError(
f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
)
return super().forward(*inputs)
def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel:
......@@ -203,15 +203,15 @@ class SimpleCopyPaste(_RandomApplyTransform):
def _copy_paste(
self,
image: Any,
image: features.TensorImageType,
target: Dict[str, Any],
paste_image: Any,
paste_image: features.TensorImageType,
paste_target: Dict[str, Any],
random_selection: torch.Tensor,
blending: bool,
resize_interpolation: F.InterpolationMode,
antialias: Optional[bool],
) -> Tuple[Any, Dict[str, Any]]:
) -> Tuple[features.TensorImageType, Dict[str, Any]]:
paste_masks = paste_target["masks"].new_like(paste_target["masks"], paste_target["masks"][random_selection])
paste_boxes = paste_target["boxes"].new_like(paste_target["boxes"], paste_target["boxes"][random_selection])
......@@ -223,7 +223,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
# This is something different to TF implementation we introduced here as
# originally the algorithm works on equal-sized data
# (for example, coming from LSJ data augmentations)
size1 = image.shape[-2:]
size1 = cast(List[int], image.shape[-2:])
size2 = paste_image.shape[-2:]
if size1 != size2:
paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation, antialias=antialias)
......@@ -278,7 +278,9 @@ class SimpleCopyPaste(_RandomApplyTransform):
return image, out_target
def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], List[Dict[str, Any]]]:
def _extract_image_targets(
self, flat_sample: List[Any]
) -> Tuple[List[features.TensorImageType], List[Dict[str, Any]]]:
# fetch all images, bboxes, masks and labels from unstructured input
# with List[image], List[BoundingBox], List[Mask], List[Label]
images, bboxes, masks, labels = [], [], [], []
......@@ -307,7 +309,10 @@ class SimpleCopyPaste(_RandomApplyTransform):
return images, targets
def _insert_outputs(
self, flat_sample: List[Any], output_images: List[Any], output_targets: List[Dict[str, Any]]
self,
flat_sample: List[Any],
output_images: List[features.TensorImageType],
output_targets: List[Dict[str, Any]],
) -> None:
c0, c1, c2, c3 = 0, 0, 0, 0
for i, obj in enumerate(flat_sample):
......
......@@ -9,7 +9,7 @@ from torchvision.prototype import features
from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._meta import get_chw
from ._utils import _isinstance, _setup_fill_arg, FillType
from ._utils import _isinstance, _setup_fill_arg
K = TypeVar("K")
V = TypeVar("V")
......@@ -20,7 +20,7 @@ class _AutoAugmentBase(Transform):
self,
*,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
) -> None:
super().__init__()
self.interpolation = interpolation
......@@ -35,7 +35,7 @@ class _AutoAugmentBase(Transform):
self,
sample: Any,
unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.Mask),
) -> Tuple[int, Union[PIL.Image.Image, torch.Tensor, features.Image]]:
) -> Tuple[int, features.ImageType]:
sample_flat, _ = tree_flatten(sample)
images = []
for id, inpt in enumerate(sample_flat):
......@@ -59,12 +59,12 @@ class _AutoAugmentBase(Transform):
def _apply_image_transform(
self,
image: Union[torch.Tensor, PIL.Image.Image, features.Image],
image: features.ImageType,
transform_id: str,
magnitude: float,
interpolation: InterpolationMode,
fill: Union[Dict[Type, FillType], Dict[Type, None]],
) -> Any:
fill: Dict[Type, features.FillType],
) -> features.ImageType:
fill_ = fill[type(image)]
fill_ = F._geometry._convert_fill_arg(fill_)
......@@ -177,7 +177,7 @@ class AutoAugment(_AutoAugmentBase):
self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.policy = policy
......@@ -337,7 +337,7 @@ class RandAugment(_AutoAugmentBase):
magnitude: int = 9,
num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.num_ops = num_ops
......@@ -393,7 +393,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
self,
num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
):
super().__init__(interpolation=interpolation, fill=fill)
self.num_magnitude_bins = num_magnitude_bins
......@@ -453,7 +453,7 @@ class AugMix(_AutoAugmentBase):
alpha: float = 1.0,
all_ops: bool = True,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self._PARAMETER_MAX = 10
......
import collections.abc
from typing import Any, Dict, Optional, Sequence, Tuple, TypeVar, Union
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import PIL.Image
import torch
......@@ -9,8 +9,6 @@ from torchvision.prototype.transforms import functional as F, Transform
from ._transform import _RandomApplyTransform
from ._utils import query_chw
T = TypeVar("T", features.Image, torch.Tensor, PIL.Image.Image)
class ColorJitter(Transform):
def __init__(
......@@ -112,7 +110,7 @@ class RandomPhotometricDistort(Transform):
channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None,
)
def _permute_channels(self, inpt: Any, permutation: torch.Tensor) -> Any:
def _permute_channels(self, inpt: features.ImageType, permutation: torch.Tensor) -> features.ImageType:
if isinstance(inpt, PIL.Image.Image):
inpt = F.pil_to_tensor(inpt)
......@@ -125,9 +123,7 @@ class RandomPhotometricDistort(Transform):
return output
def _transform(
self, inpt: Union[torch.Tensor, features.Image, PIL.Image.Image], params: Dict[str, Any]
) -> Union[torch.Tensor, features.Image, PIL.Image.Image]:
def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
if params["brightness"]:
inpt = F.adjust_brightness(
inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1])
......
......@@ -11,7 +11,7 @@ from torchvision.transforms import functional as _F
from typing_extensions import Literal
from ._transform import _RandomApplyTransform
from ._utils import DType, query_chw
from ._utils import query_chw
class ToTensor(Transform):
......@@ -52,7 +52,7 @@ class Grayscale(Transform):
super().__init__()
self.num_output_channels = num_output_channels
def _transform(self, inpt: DType, params: Dict[str, Any]) -> DType:
def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
if isinstance(inpt, features.Image):
output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.GRAY)
......@@ -81,7 +81,7 @@ class RandomGrayscale(_RandomApplyTransform):
num_input_channels, _, _ = query_chw(sample)
return dict(num_input_channels=num_input_channels)
def _transform(self, inpt: DType, params: Dict[str, Any]) -> DType:
def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
if isinstance(inpt, features.Image):
output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.GRAY)
......
......@@ -20,8 +20,6 @@ from ._utils import (
_setup_angle,
_setup_fill_arg,
_setup_size,
DType,
FillType,
has_all,
has_any,
query_bounding_box,
......@@ -179,7 +177,9 @@ class FiveCrop(Transform):
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
def _transform(self, inpt: DType, params: Dict[str, Any]) -> Tuple[DType, DType, DType, DType, DType]:
def _transform(
self, inpt: features.ImageType, params: Dict[str, Any]
) -> Tuple[features.ImageType, features.ImageType, features.ImageType, features.ImageType, features.ImageType]:
return F.five_crop(inpt, self.size)
def forward(self, *inputs: Any) -> Any:
......@@ -200,7 +200,7 @@ class TenCrop(Transform):
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
self.vertical_flip = vertical_flip
def _transform(self, inpt: DType, params: Dict[str, Any]) -> List[DType]:
def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> List[features.ImageType]:
return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
def forward(self, *inputs: Any) -> Any:
......@@ -213,7 +213,7 @@ class Pad(Transform):
def __init__(
self,
padding: Union[int, Sequence[int]],
fill: Union[FillType, Dict[Type, FillType]] = 0,
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
......@@ -240,7 +240,7 @@ class Pad(Transform):
class RandomZoomOut(_RandomApplyTransform):
def __init__(
self,
fill: Union[FillType, Dict[Type, FillType]] = 0,
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
side_range: Sequence[float] = (1.0, 4.0),
p: float = 0.5,
) -> None:
......@@ -282,7 +282,7 @@ class RandomRotation(Transform):
degrees: Union[numbers.Number, Sequence],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Union[FillType, Dict[Type, FillType]] = 0,
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
......@@ -322,7 +322,7 @@ class RandomAffine(Transform):
scale: Optional[Sequence[float]] = None,
shear: Optional[Union[float, Sequence[float]]] = None,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[FillType, Dict[Type, FillType]] = 0,
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
......@@ -401,7 +401,7 @@ class RandomCrop(Transform):
size: Union[int, Sequence[int]],
padding: Optional[Union[int, Sequence[int]]] = None,
pad_if_needed: bool = False,
fill: Union[FillType, Dict[Type, FillType]] = 0,
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
......@@ -491,7 +491,7 @@ class RandomPerspective(_RandomApplyTransform):
def __init__(
self,
distortion_scale: float = 0.5,
fill: Union[FillType, Dict[Type, FillType]] = 0,
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
p: float = 0.5,
) -> None:
......@@ -567,7 +567,7 @@ class ElasticTransform(Transform):
self,
alpha: Union[float, Sequence[float]] = 50.0,
sigma: Union[float, Sequence[float]] = 5.0,
fill: Union[FillType, Dict[Type, FillType]] = 0,
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> None:
super().__init__()
......@@ -780,7 +780,7 @@ class FixedSizeCrop(Transform):
def __init__(
self,
size: Union[int, Sequence[int]],
fill: Union[FillType, Dict[Type, FillType]] = 0,
fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
padding_mode: str = "constant",
) -> None:
super().__init__()
......
......@@ -28,9 +28,7 @@ class ConvertImageDtype(Transform):
super().__init__()
self.dtype = dtype
def _transform(
self, inpt: Union[torch.Tensor, features.Image], params: Dict[str, Any]
) -> Union[torch.Tensor, features.Image]:
def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> features.TensorImageType:
output = F.convert_image_dtype(inpt, dtype=self.dtype)
return output if features.is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype) # type: ignore[arg-type]
......@@ -56,9 +54,7 @@ class ConvertColorSpace(Transform):
self.copy = copy
def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, features._Feature], params: Dict[str, Any]
) -> Union[torch.Tensor, PIL.Image.Image, features._Feature]:
def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
return F.convert_color_space(
inpt, color_space=self.color_space, old_color_space=self.old_color_space, copy=self.copy
)
......
......@@ -68,7 +68,7 @@ class LinearTransformation(Transform):
return super().forward(*inputs)
def _transform(self, inpt: Union[torch.Tensor, features._Feature], params: Dict[str, Any]) -> torch.Tensor:
def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> torch.Tensor:
# Image instance after linear transformation is not Image anymore due to unknown data range
# Thus we will return Tensor for input Image
......@@ -101,7 +101,7 @@ class Normalize(Transform):
self.std = list(std)
self.inplace = inplace
def _transform(self, inpt: Union[torch.Tensor, features._Feature], params: Dict[str, Any]) -> torch.Tensor:
def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> torch.Tensor:
return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
def forward(self, *inpts: Any) -> Any:
......
import numbers
from collections import defaultdict
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
from typing import Any, Callable, Dict, Sequence, Tuple, Type, Union
import PIL.Image
import torch
from torch.utils._pytree import tree_flatten
from torchvision._utils import sequence_to_str
from torchvision.prototype import features
from torchvision.prototype.features._feature import FillType
from torchvision.prototype.transforms.functional._meta import get_chw
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
......@@ -16,12 +16,7 @@ from torchvision.transforms.transforms import _check_sequence_input, _setup_angl
from typing_extensions import Literal
# Type shortcuts:
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
FillType = Union[int, float, Sequence[int], Sequence[float]]
def _check_fill_arg(fill: Optional[Union[FillType, Dict[Type, FillType]]]) -> None:
def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
if isinstance(fill, dict):
for key, value in fill.items():
# Check key for type
......@@ -31,15 +26,13 @@ def _check_fill_arg(fill: Optional[Union[FillType, Dict[Type, FillType]]]) -> No
raise TypeError("Got inappropriate fill arg")
def _setup_fill_arg(
fill: Optional[Union[FillType, Dict[Type, FillType]]]
) -> Union[Dict[Type, FillType], Dict[Type, None]]:
def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
_check_fill_arg(fill)
if isinstance(fill, dict):
return fill
return defaultdict(lambda: fill) # type: ignore[return-value]
return defaultdict(lambda: fill) # type: ignore[return-value, arg-type]
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
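
`_setup_fill_arg` above is what lets callers hand over either a single fill value or a per-type mapping; now that `Optional` is gone from its signature, downstream code such as `_apply_image_transform` can index unconditionally with `fill[type(image)]`. A condensed, self-contained sketch of that contract (`setup_fill_arg` is a hypothetical stand-in):

    from collections import defaultdict
    from typing import Dict, Type, Union

    FillType = Union[int, float, None]  # narrowed from the real alias for brevity

    def setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
        # A user-provided dict is taken as-is; a single value becomes a
        # defaultdict so that fill[type(image)] succeeds for every type.
        if isinstance(fill, dict):
            return fill
        return defaultdict(lambda: fill)

    per_type = setup_fill_arg(0)
    assert per_type[str] == 0  # any key yields the shared value
    assert setup_fill_arg({int: 1})[int] == 1
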
......
......@@ -5,7 +5,6 @@ from torchvision.prototype import features
from torchvision.transforms import functional_tensor as _FT
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
erase_image_tensor = _FT.erase
......@@ -19,14 +18,14 @@ def erase_image_pil(
def erase(
inpt: features.ImageType,
inpt: features.ImageTypeJIT,
i: int,
j: int,
h: int,
w: int,
v: torch.Tensor,
inplace: bool = False,
) -> features.ImageType:
) -> features.ImageTypeJIT:
if isinstance(inpt, torch.Tensor):
output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
......
......@@ -2,12 +2,11 @@ import torch
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
adjust_brightness_image_tensor = _FT.adjust_brightness
adjust_brightness_image_pil = _FP.adjust_brightness
def adjust_brightness(inpt: features.DType, brightness_factor: float) -> features.DType:
def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
elif isinstance(inpt, features._Feature):
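
Every functional in these hunks repeats one dispatch idiom, which the diff context truncates after the `elif`. Reconstructed in full as a sketch, with the elided branches inferred from the surrounding code rather than shown in this diff: bare tensors, or any input while scripting, hit the tensor kernel; `_Feature` subclasses defer to their own method; anything else is assumed to be PIL.

    import torch
    from torchvision.prototype import features
    from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

    adjust_brightness_image_tensor = _FT.adjust_brightness
    adjust_brightness_image_pil = _FP.adjust_brightness

    def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) -> features.InputTypeJIT:
        if isinstance(inpt, torch.Tensor) and (
            torch.jit.is_scripting() or not isinstance(inpt, features._Feature)
        ):
            # plain tensor, or anything at all while scripting
            return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
        elif isinstance(inpt, features._Feature):
            # features (Image, Mask, ...) dispatch to their own method
            return inpt.adjust_brightness(brightness_factor=brightness_factor)
        else:
            # assumed to be a PIL image
            return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
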
......@@ -20,7 +19,7 @@ adjust_saturation_image_tensor = _FT.adjust_saturation
adjust_saturation_image_pil = _FP.adjust_saturation
def adjust_saturation(inpt: features.DType, saturation_factor: float) -> features.DType:
def adjust_saturation(inpt: features.InputTypeJIT, saturation_factor: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
elif isinstance(inpt, features._Feature):
......@@ -33,7 +32,7 @@ adjust_contrast_image_tensor = _FT.adjust_contrast
adjust_contrast_image_pil = _FP.adjust_contrast
def adjust_contrast(inpt: features.DType, contrast_factor: float) -> features.DType:
def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
elif isinstance(inpt, features._Feature):
......@@ -46,7 +45,7 @@ adjust_sharpness_image_tensor = _FT.adjust_sharpness
adjust_sharpness_image_pil = _FP.adjust_sharpness
def adjust_sharpness(inpt: features.DType, sharpness_factor: float) -> features.DType:
def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
elif isinstance(inpt, features._Feature):
......@@ -59,7 +58,7 @@ adjust_hue_image_tensor = _FT.adjust_hue
adjust_hue_image_pil = _FP.adjust_hue
def adjust_hue(inpt: features.DType, hue_factor: float) -> features.DType:
def adjust_hue(inpt: features.InputTypeJIT, hue_factor: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
elif isinstance(inpt, features._Feature):
......@@ -72,7 +71,7 @@ adjust_gamma_image_tensor = _FT.adjust_gamma
adjust_gamma_image_pil = _FP.adjust_gamma
def adjust_gamma(inpt: features.DType, gamma: float, gain: float = 1) -> features.DType:
def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
elif isinstance(inpt, features._Feature):
......@@ -85,7 +84,7 @@ posterize_image_tensor = _FT.posterize
posterize_image_pil = _FP.posterize
def posterize(inpt: features.DType, bits: int) -> features.DType:
def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return posterize_image_tensor(inpt, bits=bits)
elif isinstance(inpt, features._Feature):
......@@ -98,7 +97,7 @@ solarize_image_tensor = _FT.solarize
solarize_image_pil = _FP.solarize
def solarize(inpt: features.DType, threshold: float) -> features.DType:
def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return solarize_image_tensor(inpt, threshold=threshold)
elif isinstance(inpt, features._Feature):
......@@ -111,7 +110,7 @@ autocontrast_image_tensor = _FT.autocontrast
autocontrast_image_pil = _FP.autocontrast
def autocontrast(inpt: features.DType) -> features.DType:
def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return autocontrast_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
......@@ -124,7 +123,7 @@ equalize_image_tensor = _FT.equalize
equalize_image_pil = _FP.equalize
def equalize(inpt: features.DType) -> features.DType:
def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return equalize_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
......@@ -137,7 +136,7 @@ invert_image_tensor = _FT.invert
invert_image_pil = _FP.invert
def invert(inpt: features.DType) -> features.DType:
def invert(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return invert_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
......
......@@ -8,11 +8,6 @@ from torchvision.prototype import features
from torchvision.transforms import functional as _F
# Due to torch.jit.script limitation we keep LegacyImageType as torch.Tensor
# instead of Union[torch.Tensor, PIL.Image.Image]
LegacyImageType = torch.Tensor
@torch.jit.unused
def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
call = ", num_output_channels=3" if num_output_channels == 3 else ""
......@@ -27,7 +22,7 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
return _F.to_grayscale(inpt, num_output_channels=num_output_channels)
def rgb_to_grayscale(inpt: LegacyImageType, num_output_channels: int = 1) -> LegacyImageType:
def rgb_to_grayscale(inpt: features.LegacyImageTypeJIT, num_output_channels: int = 1) -> features.LegacyImageTypeJIT:
old_color_space = (
features._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type]
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image))
......@@ -61,7 +56,7 @@ def to_tensor(inpt: Any) -> torch.Tensor:
return _F.to_tensor(inpt)
def get_image_size(inpt: features.ImageType) -> List[int]:
def get_image_size(inpt: features.ImageTypeJIT) -> List[int]:
warnings.warn(
"The function `get_image_size(...)` is deprecated and will be removed in a future release. "
"Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`."
......
......@@ -18,7 +18,6 @@ from torchvision.transforms.functional_tensor import _parse_pad_padding
from ._meta import convert_format_bounding_box, get_dimensions_image_pil, get_dimensions_image_tensor
horizontal_flip_image_tensor = _FT.hflip
horizontal_flip_image_pil = _FP.hflip
......@@ -43,7 +42,7 @@ def horizontal_flip_bounding_box(
).view(shape)
def horizontal_flip(inpt: features.DType) -> features.DType:
def horizontal_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return horizontal_flip_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
......@@ -76,7 +75,7 @@ def vertical_flip_bounding_box(
).view(shape)
def vertical_flip(inpt: features.DType) -> features.DType:
def vertical_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return vertical_flip_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
......@@ -153,12 +152,12 @@ def resize_bounding_box(
def resize(
inpt: features.DType,
inpt: features.InputTypeJIT,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[bool] = None,
) -> features.DType:
) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
antialias = False if antialias is None else antialias
return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
......@@ -228,7 +227,7 @@ def affine_image_tensor(
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, List[float]]] = None,
fill: features.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
if img.numel() == 0:
......@@ -260,7 +259,7 @@ def affine_image_pil(
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, List[float]]] = None,
fill: features.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> PIL.Image.Image:
angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
......@@ -378,7 +377,7 @@ def affine_mask(
translate: List[float],
scale: float,
shear: List[float],
fill: Optional[Union[int, float, List[float]]] = None,
fill: features.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
if mask.ndim < 3:
......@@ -404,9 +403,7 @@ def affine_mask(
return output
def _convert_fill_arg(
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]
) -> Optional[Union[int, float, List[float]]]:
def _convert_fill_arg(fill: features.FillType) -> features.FillTypeJIT:
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# So, we can't reassign fill to 0
# if fill is None:
......@@ -421,15 +418,15 @@ def _convert_fill_arg(
def affine(
inpt: features.DType,
inpt: features.InputTypeJIT,
angle: float,
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, List[float]]] = None,
fill: features.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> features.DType:
) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return affine_image_tensor(
inpt,
......@@ -463,7 +460,7 @@ def rotate_image_tensor(
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[Union[int, float, List[float]]] = None,
fill: features.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
num_channels, height, width = img.shape[-3:]
......@@ -502,7 +499,7 @@ def rotate_image_pil(
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[Union[int, float, List[float]]] = None,
fill: features.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> PIL.Image.Image:
if center is not None and expand:
......@@ -542,7 +539,7 @@ def rotate_mask(
mask: torch.Tensor,
angle: float,
expand: bool = False,
fill: Optional[Union[int, float, List[float]]] = None,
fill: features.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
if mask.ndim < 3:
......@@ -567,13 +564,13 @@ def rotate_mask(
def rotate(
inpt: features.DType,
inpt: features.InputTypeJIT,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[Union[int, float, List[float]]] = None,
fill: features.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> features.DType:
) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
elif isinstance(inpt, features._Feature):
......@@ -588,7 +585,7 @@ pad_image_pil = _FP.pad
def pad_image_tensor(
img: torch.Tensor,
padding: Union[int, List[int]],
fill: Optional[Union[int, float, List[float]]] = None,
fill: features.FillTypeJIT = None,
padding_mode: str = "constant",
) -> torch.Tensor:
if fill is None:
......@@ -652,7 +649,7 @@ def pad_mask(
mask: torch.Tensor,
padding: Union[int, List[int]],
padding_mode: str = "constant",
fill: Optional[Union[int, float, List[float]]] = None,
fill: features.FillTypeJIT = None,
) -> torch.Tensor:
if fill is None:
fill = 0
......@@ -698,11 +695,11 @@ def pad_bounding_box(
def pad(
inpt: features.DType,
inpt: features.InputTypeJIT,
padding: Union[int, List[int]],
fill: Optional[Union[int, float, List[float]]] = None,
fill: features.FillTypeJIT = None,
padding_mode: str = "constant",
) -> features.DType:
) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
......@@ -739,7 +736,7 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int)
return crop_image_tensor(mask, top, left, height, width)
def crop(inpt: features.DType, top: int, left: int, height: int, width: int) -> features.DType:
def crop(inpt: features.InputTypeJIT, top: int, left: int, height: int, width: int) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return crop_image_tensor(inpt, top, left, height, width)
elif isinstance(inpt, features._Feature):
......@@ -752,7 +749,7 @@ def perspective_image_tensor(
img: torch.Tensor,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, List[float]]] = None,
fill: features.FillTypeJIT = None,
) -> torch.Tensor:
return _FT.perspective(img, perspective_coeffs, interpolation=interpolation.value, fill=fill)
......@@ -762,7 +759,7 @@ def perspective_image_pil(
img: PIL.Image.Image,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BICUBIC,
fill: Optional[Union[int, float, List[float]]] = None,
fill: features.FillTypeJIT = None,
) -> PIL.Image.Image:
return _FP.perspective(img, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)
......@@ -855,7 +852,7 @@ def perspective_bounding_box(
def perspective_mask(
mask: torch.Tensor,
perspective_coeffs: List[float],
fill: Optional[Union[int, float, List[float]]] = None,
fill: features.FillTypeJIT = None,
) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
......@@ -874,11 +871,11 @@ def perspective_mask(
def perspective(
inpt: features.DType,
inpt: features.InputTypeJIT,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, List[float]]] = None,
) -> features.DType:
fill: features.FillTypeJIT = None,
) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
elif isinstance(inpt, features._Feature):
......@@ -891,7 +888,7 @@ def elastic_image_tensor(
img: torch.Tensor,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, List[float]]] = None,
fill: features.FillTypeJIT = None,
) -> torch.Tensor:
return _FT.elastic_transform(img, displacement, interpolation=interpolation.value, fill=fill)
......@@ -901,7 +898,7 @@ def elastic_image_pil(
img: PIL.Image.Image,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, List[float]]] = None,
fill: features.FillTypeJIT = None,
) -> PIL.Image.Image:
t_img = pil_to_tensor(img)
output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill)
......@@ -951,7 +948,7 @@ def elastic_bounding_box(
def elastic_mask(
mask: torch.Tensor,
displacement: torch.Tensor,
fill: Optional[Union[int, float, List[float]]] = None,
fill: features.FillTypeJIT = None,
) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
......@@ -968,11 +965,11 @@ def elastic_mask(
def elastic(
inpt: features.DType,
inpt: features.InputTypeJIT,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, List[float]]] = None,
) -> features.DType:
fill: features.FillTypeJIT = None,
) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
elif isinstance(inpt, features._Feature):
......@@ -1069,7 +1066,7 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor
return output
def center_crop(inpt: features.DType, output_size: List[int]) -> features.DType:
def center_crop(inpt: features.InputTypeJIT, output_size: List[int]) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return center_crop_image_tensor(inpt, output_size)
elif isinstance(inpt, features._Feature):
......@@ -1132,7 +1129,7 @@ def resized_crop_mask(
def resized_crop(
inpt: features.DType,
inpt: features.InputTypeJIT,
top: int,
left: int,
height: int,
......@@ -1140,7 +1137,7 @@ def resized_crop(
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
) -> features.DType:
) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
antialias = False if antialias is None else antialias
return resized_crop_image_tensor(
......@@ -1205,9 +1202,11 @@ def five_crop_image_pil(
def five_crop(
inpt: features.ImageType, size: List[int]
) -> Tuple[features.ImageType, features.ImageType, features.ImageType, features.ImageType, features.ImageType]:
# TODO: consider breaking BC here to return List[features.ImageType] to align this op with `ten_crop`
inpt: features.ImageTypeJIT, size: List[int]
) -> Tuple[
features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT
]:
# TODO: consider breaking BC here to return List[features.ImageTypeJIT] to align this op with `ten_crop`
if isinstance(inpt, torch.Tensor):
output = five_crop_image_tensor(inpt, size)
if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
......@@ -1244,7 +1243,7 @@ def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: boo
return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
def ten_crop(inpt: features.ImageType, size: List[int], vertical_flip: bool = False) -> List[features.ImageType]:
def ten_crop(inpt: features.ImageTypeJIT, size: List[int], vertical_flip: bool = False) -> List[features.ImageTypeJIT]:
if isinstance(inpt, torch.Tensor):
output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
......
......@@ -11,7 +11,7 @@ get_dimensions_image_pil = _FP.get_dimensions
# TODO: Should this be prefixed with `_` similar to other methods that don't get exposed by init?
def get_chw(image: features.ImageType) -> Tuple[int, int, int]:
def get_chw(image: features.ImageTypeJIT) -> Tuple[int, int, int]:
if isinstance(image, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(image, features.Image)):
channels, height, width = get_dimensions_image_tensor(image)
elif isinstance(image, features.Image):
......@@ -29,11 +29,11 @@ def get_chw(image: features.ImageType) -> Tuple[int, int, int]:
# detailed above.
def get_dimensions(image: features.ImageType) -> List[int]:
def get_dimensions(image: features.ImageTypeJIT) -> List[int]:
return list(get_chw(image))
def get_num_channels(image: features.ImageType) -> int:
def get_num_channels(image: features.ImageTypeJIT) -> int:
num_channels, *_ = get_chw(image)
return num_channels
......@@ -43,7 +43,7 @@ def get_num_channels(image: features.ImageType) -> int:
get_image_num_channels = get_num_channels
def get_spatial_size(image: features.ImageType) -> List[int]:
def get_spatial_size(image: features.ImageTypeJIT) -> List[int]:
_, *size = get_chw(image)
return size
......@@ -208,11 +208,11 @@ def convert_color_space_image_pil(
def convert_color_space(
inpt: features.ImageType,
inpt: features.ImageTypeJIT,
color_space: ColorSpace,
old_color_space: Optional[ColorSpace] = None,
copy: bool = True,
) -> features.ImageType:
) -> features.ImageTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image)):
if old_color_space is None:
raise RuntimeError(
......@@ -225,4 +225,4 @@ def convert_color_space(
elif isinstance(inpt, features.Image):
return inpt.to_color_space(color_space, copy=copy)
else:
return cast(features.ImageType, convert_color_space_image_pil(inpt, color_space, copy=copy))
return cast(features.ImageTypeJIT, convert_color_space_image_pil(inpt, color_space, copy=copy))
......@@ -6,16 +6,12 @@ from torchvision.prototype import features
from torchvision.transforms import functional_tensor as _FT
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
# Due to torch.jit.script limitation we keep TensorImageType as torch.Tensor
# instead of Union[torch.Tensor, features.Image]
TensorImageType = torch.Tensor
normalize_image_tensor = _FT.normalize
def normalize(inpt: TensorImageType, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
def normalize(
inpt: features.TensorImageTypeJIT, mean: List[float], std: List[float], inplace: bool = False
) -> torch.Tensor:
if not isinstance(inpt, torch.Tensor):
raise TypeError(f"img should be Tensor Image. Got {type(inpt)}")
else:
......@@ -62,7 +58,9 @@ def gaussian_blur_image_pil(
return to_pil_image(output, mode=img.mode)
def gaussian_blur(inpt: features.DType, kernel_size: List[int], sigma: Optional[List[float]] = None) -> features.DType:
def gaussian_blur(
inpt: features.InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
elif isinstance(inpt, features._Feature):
......