Unverified Commit f725901d authored by Vasilis Vryniotis, committed by GitHub

[prototype] Align and Clean up transform types (#6627)



* Align and Clean up transform types

* Move type definitions to `_utils.py`

* Fixing error message in tests

* Apply code review suggestions
Co-authored-by: vfdev <vfdev.5@gmail.com>

* Centralizing types and switching to always getting dicts.

* Fixing linter

* Refactoring typing definitions.

* Remove relative imports.

* Reuse type.

* Temporarily remove the TorchData tests.

* Restore the TorchData tests.
Co-authored-by: default avatarvfdev <vfdev.5@gmail.com>
parent 6b1646ca
@@ -132,7 +132,7 @@ class TestSmoke:
         transform(input_copy)
 
         # Check if we raise an error if sample contains bbox or mask or label
-        err_msg = "does not support bounding boxes, masks and plain labels"
+        err_msg = "does not support PIL images, bounding boxes, masks and plain labels"
         input_copy = dict(input)
         for unsup_data in [
             make_label(),
...
 from ._bounding_box import BoundingBox, BoundingBoxFormat
 from ._encoded import EncodedData, EncodedImage, EncodedVideo
-from ._feature import _Feature, DType, is_simple_tensor
-from ._image import ColorSpace, Image, ImageType
+from ._feature import _Feature, FillType, FillTypeJIT, InputType, InputTypeJIT, is_simple_tensor
+from ._image import (
+    ColorSpace,
+    Image,
+    ImageType,
+    ImageTypeJIT,
+    LegacyImageType,
+    LegacyImageTypeJIT,
+    TensorImageType,
+    TensorImageTypeJIT,
+)
 from ._label import Label, OneHotLabel
 from ._mask import Mask
@@ -6,7 +6,7 @@ import torch
 from torchvision._utils import StrEnum
 from torchvision.transforms import InterpolationMode  # TODO: this needs to be moved out of transforms
 
-from ._feature import _Feature
+from ._feature import _Feature, FillTypeJIT
 
 
 class BoundingBoxFormat(StrEnum):
@@ -115,7 +115,7 @@ class BoundingBox(_Feature):
     def pad(
         self,
         padding: Union[int, Sequence[int]],
-        fill: Optional[Union[int, float, List[float]]] = None,
+        fill: FillTypeJIT = None,
         padding_mode: str = "constant",
     ) -> BoundingBox:
         # This cast does Sequence[int] -> List[int] and is required to make mypy happy
@@ -137,7 +137,7 @@ class BoundingBox(_Feature):
         angle: float,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: Optional[Union[int, float, List[float]]] = None,
+        fill: FillTypeJIT = None,
         center: Optional[List[float]] = None,
     ) -> BoundingBox:
         output = self._F.rotate_bounding_box(
@@ -165,7 +165,7 @@ class BoundingBox(_Feature):
         scale: float,
         shear: List[float],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[Union[int, float, List[float]]] = None,
+        fill: FillTypeJIT = None,
         center: Optional[List[float]] = None,
     ) -> BoundingBox:
         output = self._F.affine_bounding_box(
@@ -184,7 +184,7 @@ class BoundingBox(_Feature):
         self,
         perspective_coeffs: List[float],
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-        fill: Optional[Union[int, float, List[float]]] = None,
+        fill: FillTypeJIT = None,
     ) -> BoundingBox:
         output = self._F.perspective_bounding_box(self, self.format, perspective_coeffs)
         return BoundingBox.new_like(self, output, dtype=output.dtype)
@@ -193,7 +193,7 @@ class BoundingBox(_Feature):
         self,
         displacement: torch.Tensor,
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-        fill: Optional[Union[int, float, List[float]]] = None,
+        fill: FillTypeJIT = None,
     ) -> BoundingBox:
         output = self._F.elastic_bounding_box(self, self.format, displacement)
         return BoundingBox.new_like(self, output, dtype=output.dtype)
@@ -3,16 +3,14 @@ from __future__ import annotations
 from types import ModuleType
 from typing import Any, Callable, cast, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union
 
+import PIL.Image
 import torch
 from torch._C import _TensorBase, DisableTorchFunction
 from torchvision.transforms import InterpolationMode
 
 F = TypeVar("F", bound="_Feature")
 
-# Due to torch.jit.script limitation we keep DType as torch.Tensor
-# instead of Union[torch.Tensor, PIL.Image.Image, features._Feature]
-DType = torch.Tensor
+FillType = Union[int, float, Sequence[int], Sequence[float], None]
+FillTypeJIT = Union[int, float, List[float], None]
 
 
 def is_simple_tensor(inpt: Any) -> bool:
@@ -154,7 +152,7 @@ class _Feature(torch.Tensor):
     def pad(
         self,
         padding: Union[int, List[int]],
-        fill: Optional[Union[int, float, List[float]]] = None,
+        fill: FillTypeJIT = None,
         padding_mode: str = "constant",
     ) -> _Feature:
         return self
@@ -164,7 +162,7 @@ class _Feature(torch.Tensor):
         angle: float,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: Optional[Union[int, float, List[float]]] = None,
+        fill: FillTypeJIT = None,
         center: Optional[List[float]] = None,
     ) -> _Feature:
         return self
@@ -176,7 +174,7 @@ class _Feature(torch.Tensor):
         scale: float,
         shear: List[float],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[Union[int, float, List[float]]] = None,
+        fill: FillTypeJIT = None,
         center: Optional[List[float]] = None,
     ) -> _Feature:
         return self
@@ -185,7 +183,7 @@ class _Feature(torch.Tensor):
         self,
         perspective_coeffs: List[float],
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-        fill: Optional[Union[int, float, List[float]]] = None,
+        fill: FillTypeJIT = None,
     ) -> _Feature:
         return self
@@ -193,7 +191,7 @@ class _Feature(torch.Tensor):
         self,
         displacement: torch.Tensor,
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-        fill: Optional[Union[int, float, List[float]]] = None,
+        fill: FillTypeJIT = None,
     ) -> _Feature:
         return self
@@ -232,3 +230,7 @@ class _Feature(torch.Tensor):
     def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> _Feature:
         return self
+
+
+InputType = Union[torch.Tensor, PIL.Image.Image, _Feature]
+InputTypeJIT = torch.Tensor
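Note on the new aliases: each eager type has a `*JIT` twin that collapses to `torch.Tensor` because `torch.jit.script` cannot compile annotations referencing `PIL.Image.Image` or the prototype feature classes. A minimal sketch of the constraint, assuming torch >= 1.10 for `Union` support; the `pad_stub` function below is hypothetical, for illustration only:

```python
from typing import List, Union

import torch

# Same shape as the new alias: the JIT variant of the input type is just Tensor.
FillTypeJIT = Union[int, float, List[float], None]


def pad_stub(inpt: torch.Tensor, fill: FillTypeJIT = None) -> torch.Tensor:
    # Hypothetical kernel body; pure tensor ops keep the function scriptable.
    return inpt


# Scripting succeeds because every annotation is TorchScript-compatible.
# Swapping the Tensor annotation for Union[torch.Tensor, PIL.Image.Image, ...]
# would make torch.jit.script fail at compile time.
scripted = torch.jit.script(pad_stub)
print(scripted(torch.rand(3, 4, 4)).shape)
```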
@@ -3,18 +3,14 @@ from __future__ import annotations
 import warnings
 from typing import Any, cast, List, Optional, Tuple, Union
 
+import PIL.Image
 import torch
 from torchvision._utils import StrEnum
 from torchvision.transforms.functional import InterpolationMode, to_pil_image
 from torchvision.utils import draw_bounding_boxes, make_grid
 
 from ._bounding_box import BoundingBox
-from ._feature import _Feature
-
-# Due to torch.jit.script limitation we keep ImageType as torch.Tensor
-# instead of Union[torch.Tensor, PIL.Image.Image, features.Image]
-ImageType = torch.Tensor
+from ._feature import _Feature, FillTypeJIT
 
 
 class ColorSpace(StrEnum):
@@ -181,7 +177,7 @@ class Image(_Feature):
     def pad(
         self,
         padding: Union[int, List[int]],
-        fill: Optional[Union[int, float, List[float]]] = None,
+        fill: FillTypeJIT = None,
         padding_mode: str = "constant",
     ) -> Image:
         output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
@@ -192,7 +188,7 @@ class Image(_Feature):
         angle: float,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: Optional[Union[int, float, List[float]]] = None,
+        fill: FillTypeJIT = None,
         center: Optional[List[float]] = None,
     ) -> Image:
         output = self._F._geometry.rotate_image_tensor(
@@ -207,7 +203,7 @@ class Image(_Feature):
         scale: float,
         shear: List[float],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[Union[int, float, List[float]]] = None,
+        fill: FillTypeJIT = None,
         center: Optional[List[float]] = None,
     ) -> Image:
         output = self._F._geometry.affine_image_tensor(
@@ -226,7 +222,7 @@ class Image(_Feature):
         self,
         perspective_coeffs: List[float],
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-        fill: Optional[Union[int, float, List[float]]] = None,
+        fill: FillTypeJIT = None,
     ) -> Image:
         output = self._F._geometry.perspective_image_tensor(
             self, perspective_coeffs, interpolation=interpolation, fill=fill
@@ -237,7 +233,7 @@ class Image(_Feature):
         self,
         displacement: torch.Tensor,
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-        fill: Optional[Union[int, float, List[float]]] = None,
+        fill: FillTypeJIT = None,
     ) -> Image:
         output = self._F._geometry.elastic_image_tensor(self, displacement, interpolation=interpolation, fill=fill)
         return Image.new_like(self, output)
@@ -289,3 +285,11 @@ class Image(_Feature):
     def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image:
         output = self._F.gaussian_blur_image_tensor(self, kernel_size=kernel_size, sigma=sigma)
         return Image.new_like(self, output)
+
+
+ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
+ImageTypeJIT = torch.Tensor
+LegacyImageType = Union[torch.Tensor, PIL.Image.Image]
+LegacyImageTypeJIT = torch.Tensor
+TensorImageType = Union[torch.Tensor, Image]
+TensorImageTypeJIT = torch.Tensor
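A quick reading guide for the alias family (my summary, not text from the diff): `ImageType` covers everything an eager transform accepts, `LegacyImageType` drops the prototype `Image` class for the deprecated helpers, and `TensorImageType` excludes PIL for kernels that require tensors; every `*JIT` twin is the TorchScript-visible narrowing. A hedged usage sketch, where the `describe` helper is hypothetical:

```python
import PIL.Image
import torch
from torchvision.prototype import features


def describe(inpt: features.ImageType) -> str:
    # Eager-only dispatch over the three members of ImageType; under
    # torch.jit.script the ImageTypeJIT alias admits only the tensor case.
    if isinstance(inpt, features.Image):
        return "prototype Image feature"
    elif isinstance(inpt, PIL.Image.Image):
        return "PIL image"
    else:
        return "plain tensor"


print(describe(torch.rand(3, 4, 4)))  # plain tensor
```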
@@ -5,7 +5,7 @@ from typing import List, Optional, Union
 import torch
 from torchvision.transforms import InterpolationMode
 
-from ._feature import _Feature
+from ._feature import _Feature, FillTypeJIT
 
 
 class Mask(_Feature):
@@ -51,7 +51,7 @@ class Mask(_Feature):
     def pad(
         self,
         padding: Union[int, List[int]],
-        fill: Optional[Union[int, float, List[float]]] = None,
+        fill: FillTypeJIT = None,
         padding_mode: str = "constant",
     ) -> Mask:
         output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill)
@@ -62,7 +62,7 @@ class Mask(_Feature):
         angle: float,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: Optional[Union[int, float, List[float]]] = None,
+        fill: FillTypeJIT = None,
         center: Optional[List[float]] = None,
     ) -> Mask:
         output = self._F.rotate_mask(self, angle, expand=expand, center=center, fill=fill)
@@ -75,7 +75,7 @@ class Mask(_Feature):
         scale: float,
         shear: List[float],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[Union[int, float, List[float]]] = None,
+        fill: FillTypeJIT = None,
         center: Optional[List[float]] = None,
     ) -> Mask:
         output = self._F.affine_mask(
@@ -93,7 +93,7 @@ class Mask(_Feature):
         self,
         perspective_coeffs: List[float],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[Union[int, float, List[float]]] = None,
+        fill: FillTypeJIT = None,
     ) -> Mask:
         output = self._F.perspective_mask(self, perspective_coeffs, fill=fill)
         return Mask.new_like(self, output)
@@ -102,7 +102,7 @@ class Mask(_Feature):
         self,
         displacement: torch.Tensor,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[Union[int, float, List[float]]] = None,
+        fill: FillTypeJIT = None,
     ) -> Mask:
         output = self._F.elastic_mask(self, displacement, fill=fill)
         return Mask.new_like(self, output, dtype=output.dtype)
 import math
 import numbers
 import warnings
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, cast, Dict, List, Optional, Tuple
 
 import PIL.Image
 import torch
@@ -92,9 +92,7 @@ class RandomErasing(_RandomApplyTransform):
         return dict(i=i, j=j, h=h, w=w, v=v)
 
-    def _transform(
-        self, inpt: Union[torch.Tensor, features.Image, PIL.Image.Image], params: Dict[str, Any]
-    ) -> Union[torch.Tensor, features.Image, PIL.Image.Image]:
+    def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
         if params["v"] is not None:
             inpt = F.erase(inpt, **params, inplace=self.inplace)
@@ -110,8 +108,10 @@ class _BaseMixupCutmix(_RandomApplyTransform):
     def forward(self, *inputs: Any) -> Any:
         if not (has_any(inputs, features.Image, features.is_simple_tensor) and has_any(inputs, features.OneHotLabel)):
            raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.")
-        if has_any(inputs, features.BoundingBox, features.Mask, features.Label):
-            raise TypeError(f"{type(self).__name__}() does not support bounding boxes, masks and plain labels.")
+        if has_any(inputs, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label):
+            raise TypeError(
+                f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
+            )
         return super().forward(*inputs)
 
     def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel:
@@ -203,15 +203,15 @@ class SimpleCopyPaste(_RandomApplyTransform):
     def _copy_paste(
         self,
-        image: Any,
+        image: features.TensorImageType,
         target: Dict[str, Any],
-        paste_image: Any,
+        paste_image: features.TensorImageType,
         paste_target: Dict[str, Any],
         random_selection: torch.Tensor,
         blending: bool,
         resize_interpolation: F.InterpolationMode,
         antialias: Optional[bool],
-    ) -> Tuple[Any, Dict[str, Any]]:
+    ) -> Tuple[features.TensorImageType, Dict[str, Any]]:
         paste_masks = paste_target["masks"].new_like(paste_target["masks"], paste_target["masks"][random_selection])
         paste_boxes = paste_target["boxes"].new_like(paste_target["boxes"], paste_target["boxes"][random_selection])
@@ -223,7 +223,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
         # This is something different to TF implementation we introduced here as
         # originally the algorithm works on equal-sized data
         # (for example, coming from LSJ data augmentations)
-        size1 = image.shape[-2:]
+        size1 = cast(List[int], image.shape[-2:])
         size2 = paste_image.shape[-2:]
         if size1 != size2:
             paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation, antialias=antialias)
@@ -278,7 +278,9 @@ class SimpleCopyPaste(_RandomApplyTransform):
         return image, out_target
 
-    def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], List[Dict[str, Any]]]:
+    def _extract_image_targets(
+        self, flat_sample: List[Any]
+    ) -> Tuple[List[features.TensorImageType], List[Dict[str, Any]]]:
         # fetch all images, bboxes, masks and labels from unstructured input
         # with List[image], List[BoundingBox], List[Mask], List[Label]
         images, bboxes, masks, labels = [], [], [], []
@@ -307,7 +309,10 @@ class SimpleCopyPaste(_RandomApplyTransform):
         return images, targets
 
     def _insert_outputs(
-        self, flat_sample: List[Any], output_images: List[Any], output_targets: List[Dict[str, Any]]
+        self,
+        flat_sample: List[Any],
+        output_images: List[features.TensorImageType],
+        output_targets: List[Dict[str, Any]],
     ) -> None:
         c0, c1, c2, c3 = 0, 0, 0, 0
         for i, obj in enumerate(flat_sample):
...
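One note on the `cast(List[int], image.shape[-2:])` change above: slicing `Tensor.shape` yields a `torch.Size`, while `F.resize` is annotated to take `List[int]`. `typing.cast` only narrows the static type and is a no-op at runtime, which is why no actual conversion appears in the diff. A minimal sketch:

```python
from typing import List, cast

import torch

image = torch.rand(3, 480, 640)

# Slicing .shape returns a torch.Size (a tuple subclass) at runtime; cast()
# only tells the type checker to treat it as List[int] and converts nothing.
size = cast(List[int], image.shape[-2:])
assert tuple(size) == (480, 640)
```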
@@ -9,7 +9,7 @@ from torchvision.prototype import features
 from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
 from torchvision.prototype.transforms.functional._meta import get_chw
 
-from ._utils import _isinstance, _setup_fill_arg, FillType
+from ._utils import _isinstance, _setup_fill_arg
 
 K = TypeVar("K")
 V = TypeVar("V")
@@ -20,7 +20,7 @@ class _AutoAugmentBase(Transform):
         self,
         *,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
+        fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
     ) -> None:
         super().__init__()
         self.interpolation = interpolation
@@ -35,7 +35,7 @@ class _AutoAugmentBase(Transform):
         self,
         sample: Any,
         unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.Mask),
-    ) -> Tuple[int, Union[PIL.Image.Image, torch.Tensor, features.Image]]:
+    ) -> Tuple[int, features.ImageType]:
         sample_flat, _ = tree_flatten(sample)
         images = []
         for id, inpt in enumerate(sample_flat):
@@ -59,12 +59,12 @@ class _AutoAugmentBase(Transform):
     def _apply_image_transform(
         self,
-        image: Union[torch.Tensor, PIL.Image.Image, features.Image],
+        image: features.ImageType,
         transform_id: str,
         magnitude: float,
         interpolation: InterpolationMode,
-        fill: Union[Dict[Type, FillType], Dict[Type, None]],
-    ) -> Any:
+        fill: Dict[Type, features.FillType],
+    ) -> features.ImageType:
         fill_ = fill[type(image)]
         fill_ = F._geometry._convert_fill_arg(fill_)
@@ -177,7 +177,7 @@ class AutoAugment(_AutoAugmentBase):
         self,
         policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
+        fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self.policy = policy
@@ -337,7 +337,7 @@ class RandAugment(_AutoAugmentBase):
         magnitude: int = 9,
         num_magnitude_bins: int = 31,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
+        fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_ops = num_ops
@@ -393,7 +393,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
         self,
         num_magnitude_bins: int = 31,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
+        fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
     ):
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_magnitude_bins = num_magnitude_bins
@@ -453,7 +453,7 @@ class AugMix(_AutoAugmentBase):
         alpha: float = 1.0,
         all_ops: bool = True,
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-        fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
+        fill: Union[features.FillType, Dict[Type, features.FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self._PARAMETER_MAX = 10
...
 import collections.abc
-from typing import Any, Dict, Optional, Sequence, Tuple, TypeVar, Union
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
 
 import PIL.Image
 import torch
@@ -9,8 +9,6 @@ from torchvision.prototype.transforms import functional as F, Transform
 from ._transform import _RandomApplyTransform
 from ._utils import query_chw
 
-T = TypeVar("T", features.Image, torch.Tensor, PIL.Image.Image)
-
 
 class ColorJitter(Transform):
     def __init__(
@@ -112,7 +110,7 @@ class RandomPhotometricDistort(Transform):
             channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None,
         )
 
-    def _permute_channels(self, inpt: Any, permutation: torch.Tensor) -> Any:
+    def _permute_channels(self, inpt: features.ImageType, permutation: torch.Tensor) -> features.ImageType:
         if isinstance(inpt, PIL.Image.Image):
             inpt = F.pil_to_tensor(inpt)
@@ -125,9 +123,7 @@ class RandomPhotometricDistort(Transform):
         return output
 
-    def _transform(
-        self, inpt: Union[torch.Tensor, features.Image, PIL.Image.Image], params: Dict[str, Any]
-    ) -> Union[torch.Tensor, features.Image, PIL.Image.Image]:
+    def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
         if params["brightness"]:
             inpt = F.adjust_brightness(
                 inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1])
...
@@ -11,7 +11,7 @@ from torchvision.transforms import functional as _F
 from typing_extensions import Literal
 
 from ._transform import _RandomApplyTransform
-from ._utils import DType, query_chw
+from ._utils import query_chw
 
 
 class ToTensor(Transform):
@@ -52,7 +52,7 @@ class Grayscale(Transform):
         super().__init__()
         self.num_output_channels = num_output_channels
 
-    def _transform(self, inpt: DType, params: Dict[str, Any]) -> DType:
+    def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
         output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
         if isinstance(inpt, features.Image):
             output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.GRAY)
@@ -81,7 +81,7 @@ class RandomGrayscale(_RandomApplyTransform):
         num_input_channels, _, _ = query_chw(sample)
         return dict(num_input_channels=num_input_channels)
 
-    def _transform(self, inpt: DType, params: Dict[str, Any]) -> DType:
+    def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
         output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
         if isinstance(inpt, features.Image):
             output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.GRAY)
...
@@ -20,8 +20,6 @@ from ._utils import (
     _setup_angle,
     _setup_fill_arg,
     _setup_size,
-    DType,
-    FillType,
     has_all,
     has_any,
     query_bounding_box,
@@ -179,7 +177,9 @@ class FiveCrop(Transform):
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
 
-    def _transform(self, inpt: DType, params: Dict[str, Any]) -> Tuple[DType, DType, DType, DType, DType]:
+    def _transform(
+        self, inpt: features.ImageType, params: Dict[str, Any]
+    ) -> Tuple[features.ImageType, features.ImageType, features.ImageType, features.ImageType, features.ImageType]:
         return F.five_crop(inpt, self.size)
 
     def forward(self, *inputs: Any) -> Any:
@@ -200,7 +200,7 @@ class TenCrop(Transform):
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
         self.vertical_flip = vertical_flip
 
-    def _transform(self, inpt: DType, params: Dict[str, Any]) -> List[DType]:
+    def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> List[features.ImageType]:
         return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
 
     def forward(self, *inputs: Any) -> Any:
@@ -213,7 +213,7 @@ class Pad(Transform):
     def __init__(
         self,
         padding: Union[int, Sequence[int]],
-        fill: Union[FillType, Dict[Type, FillType]] = 0,
+        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
@@ -240,7 +240,7 @@ class Pad(Transform):
 class RandomZoomOut(_RandomApplyTransform):
     def __init__(
         self,
-        fill: Union[FillType, Dict[Type, FillType]] = 0,
+        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
         side_range: Sequence[float] = (1.0, 4.0),
         p: float = 0.5,
     ) -> None:
@@ -282,7 +282,7 @@ class RandomRotation(Transform):
         degrees: Union[numbers.Number, Sequence],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: Union[FillType, Dict[Type, FillType]] = 0,
+        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
         center: Optional[List[float]] = None,
     ) -> None:
         super().__init__()
@@ -322,7 +322,7 @@ class RandomAffine(Transform):
         scale: Optional[Sequence[float]] = None,
         shear: Optional[Union[float, Sequence[float]]] = None,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Union[FillType, Dict[Type, FillType]] = 0,
+        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
         center: Optional[List[float]] = None,
     ) -> None:
         super().__init__()
@@ -401,7 +401,7 @@ class RandomCrop(Transform):
         size: Union[int, Sequence[int]],
         padding: Optional[Union[int, Sequence[int]]] = None,
         pad_if_needed: bool = False,
-        fill: Union[FillType, Dict[Type, FillType]] = 0,
+        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
@@ -491,7 +491,7 @@ class RandomPerspective(_RandomApplyTransform):
     def __init__(
         self,
         distortion_scale: float = 0.5,
-        fill: Union[FillType, Dict[Type, FillType]] = 0,
+        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
         p: float = 0.5,
     ) -> None:
@@ -567,7 +567,7 @@ class ElasticTransform(Transform):
         self,
         alpha: Union[float, Sequence[float]] = 50.0,
         sigma: Union[float, Sequence[float]] = 5.0,
-        fill: Union[FillType, Dict[Type, FillType]] = 0,
+        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     ) -> None:
         super().__init__()
@@ -780,7 +780,7 @@ class FixedSizeCrop(Transform):
     def __init__(
         self,
         size: Union[int, Sequence[int]],
-        fill: Union[FillType, Dict[Type, FillType]] = 0,
+        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
         padding_mode: str = "constant",
     ) -> None:
         super().__init__()
...
@@ -28,9 +28,7 @@ class ConvertImageDtype(Transform):
         super().__init__()
         self.dtype = dtype
 
-    def _transform(
-        self, inpt: Union[torch.Tensor, features.Image], params: Dict[str, Any]
-    ) -> Union[torch.Tensor, features.Image]:
+    def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> features.TensorImageType:
         output = F.convert_image_dtype(inpt, dtype=self.dtype)
         return output if features.is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype)  # type: ignore[arg-type]
@@ -56,9 +54,7 @@ class ConvertColorSpace(Transform):
         self.copy = copy
 
-    def _transform(
-        self, inpt: Union[torch.Tensor, PIL.Image.Image, features._Feature], params: Dict[str, Any]
-    ) -> Union[torch.Tensor, PIL.Image.Image, features._Feature]:
+    def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
         return F.convert_color_space(
             inpt, color_space=self.color_space, old_color_space=self.old_color_space, copy=self.copy
         )
...
@@ -68,7 +68,7 @@ class LinearTransformation(Transform):
         return super().forward(*inputs)
 
-    def _transform(self, inpt: Union[torch.Tensor, features._Feature], params: Dict[str, Any]) -> torch.Tensor:
+    def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> torch.Tensor:
         # Image instance after linear transformation is not Image anymore due to unknown data range
         # Thus we will return Tensor for input Image
@@ -101,7 +101,7 @@ class Normalize(Transform):
         self.std = list(std)
         self.inplace = inplace
 
-    def _transform(self, inpt: Union[torch.Tensor, features._Feature], params: Dict[str, Any]) -> torch.Tensor:
+    def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> torch.Tensor:
         return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
 
     def forward(self, *inpts: Any) -> Any:
...
 import numbers
 from collections import defaultdict
-from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Callable, Dict, Sequence, Tuple, Type, Union
 
 import PIL.Image
-import torch
 from torch.utils._pytree import tree_flatten
 from torchvision._utils import sequence_to_str
 from torchvision.prototype import features
+from torchvision.prototype.features._feature import FillType
 from torchvision.prototype.transforms.functional._meta import get_chw
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401
@@ -16,12 +16,7 @@ from torchvision.transforms.transforms import _check_sequence_input, _setup_angl
 from typing_extensions import Literal
 
-# Type shortcuts:
-DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
-FillType = Union[int, float, Sequence[int], Sequence[float]]
-
-def _check_fill_arg(fill: Optional[Union[FillType, Dict[Type, FillType]]]) -> None:
+def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
     if isinstance(fill, dict):
         for key, value in fill.items():
             # Check key for type
@@ -31,15 +26,13 @@ def _check_fill_arg(fill: Optional[Union[FillType, Dict[Type, FillType]]]) -> No
             raise TypeError("Got inappropriate fill arg")
 
-def _setup_fill_arg(
-    fill: Optional[Union[FillType, Dict[Type, FillType]]]
-) -> Union[Dict[Type, FillType], Dict[Type, None]]:
+def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
     _check_fill_arg(fill)
 
     if isinstance(fill, dict):
         return fill
 
-    return defaultdict(lambda: fill)  # type: ignore[return-value]
+    return defaultdict(lambda: fill)  # type: ignore[return-value, arg-type]
 
 
 def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
...
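To make the centralized fill handling concrete, here is a hedged sketch of how `_setup_fill_arg` normalizes a scalar fill into the per-type mapping that `_apply_image_transform`-style code later looks up with `fill[type(image)]`; the concrete values are chosen for illustration only:

```python
from collections import defaultdict

import PIL.Image
import torch
from torchvision.prototype import features

# A scalar fill is wrapped in a defaultdict so every type maps to the same value...
fill = defaultdict(lambda: 0)
assert fill[features.Image] == fill[PIL.Image.Image] == 0

# ...while an explicit dict lets callers pick a fill per input type,
# e.g. grey padding for images but 0 (background) for segmentation masks.
fill_per_type = {features.Image: 127, torch.Tensor: 127, features.Mask: 0}
assert fill_per_type[features.Mask] == 0
```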
@@ -5,7 +5,6 @@ from torchvision.prototype import features
 from torchvision.transforms import functional_tensor as _FT
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 
-
 erase_image_tensor = _FT.erase
@@ -19,14 +18,14 @@ def erase_image_pil(
 
 def erase(
-    inpt: features.ImageType,
+    inpt: features.ImageTypeJIT,
     i: int,
     j: int,
     h: int,
     w: int,
     v: torch.Tensor,
     inplace: bool = False,
-) -> features.ImageType:
+) -> features.ImageTypeJIT:
     if isinstance(inpt, torch.Tensor):
         output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
         if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
...
@@ -2,12 +2,11 @@ import torch
 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
 
-
 adjust_brightness_image_tensor = _FT.adjust_brightness
 adjust_brightness_image_pil = _FP.adjust_brightness
 
 
-def adjust_brightness(inpt: features.DType, brightness_factor: float) -> features.DType:
+def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
     elif isinstance(inpt, features._Feature):
@@ -20,7 +19,7 @@ adjust_saturation_image_tensor = _FT.adjust_saturation
 adjust_saturation_image_pil = _FP.adjust_saturation
 
 
-def adjust_saturation(inpt: features.DType, saturation_factor: float) -> features.DType:
+def adjust_saturation(inpt: features.InputTypeJIT, saturation_factor: float) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
     elif isinstance(inpt, features._Feature):
@@ -33,7 +32,7 @@ adjust_contrast_image_tensor = _FT.adjust_contrast
 adjust_contrast_image_pil = _FP.adjust_contrast
 
 
-def adjust_contrast(inpt: features.DType, contrast_factor: float) -> features.DType:
+def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
     elif isinstance(inpt, features._Feature):
@@ -46,7 +45,7 @@ adjust_sharpness_image_tensor = _FT.adjust_sharpness
 adjust_sharpness_image_pil = _FP.adjust_sharpness
 
 
-def adjust_sharpness(inpt: features.DType, sharpness_factor: float) -> features.DType:
+def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
     elif isinstance(inpt, features._Feature):
@@ -59,7 +58,7 @@ adjust_hue_image_tensor = _FT.adjust_hue
 adjust_hue_image_pil = _FP.adjust_hue
 
 
-def adjust_hue(inpt: features.DType, hue_factor: float) -> features.DType:
+def adjust_hue(inpt: features.InputTypeJIT, hue_factor: float) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
     elif isinstance(inpt, features._Feature):
@@ -72,7 +71,7 @@ adjust_gamma_image_tensor = _FT.adjust_gamma
 adjust_gamma_image_pil = _FP.adjust_gamma
 
 
-def adjust_gamma(inpt: features.DType, gamma: float, gain: float = 1) -> features.DType:
+def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
     elif isinstance(inpt, features._Feature):
@@ -85,7 +84,7 @@ posterize_image_tensor = _FT.posterize
 posterize_image_pil = _FP.posterize
 
 
-def posterize(inpt: features.DType, bits: int) -> features.DType:
+def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return posterize_image_tensor(inpt, bits=bits)
     elif isinstance(inpt, features._Feature):
@@ -98,7 +97,7 @@ solarize_image_tensor = _FT.solarize
 solarize_image_pil = _FP.solarize
 
 
-def solarize(inpt: features.DType, threshold: float) -> features.DType:
+def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return solarize_image_tensor(inpt, threshold=threshold)
     elif isinstance(inpt, features._Feature):
@@ -111,7 +110,7 @@ autocontrast_image_tensor = _FT.autocontrast
 autocontrast_image_pil = _FP.autocontrast
 
 
-def autocontrast(inpt: features.DType) -> features.DType:
+def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return autocontrast_image_tensor(inpt)
     elif isinstance(inpt, features._Feature):
@@ -124,7 +123,7 @@ equalize_image_tensor = _FT.equalize
 equalize_image_pil = _FP.equalize
 
 
-def equalize(inpt: features.DType) -> features.DType:
+def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return equalize_image_tensor(inpt)
     elif isinstance(inpt, features._Feature):
@@ -137,7 +136,7 @@ invert_image_tensor = _FT.invert
 invert_image_pil = _FP.invert
 
 
-def invert(inpt: features.DType) -> features.DType:
+def invert(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return invert_image_tensor(inpt)
     elif isinstance(inpt, features._Feature):
...
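Every functional above uses the same three-way dispatch, so it is worth spelling out once. A hedged sketch of the pattern, with a hypothetical `invert_stub` and stand-in kernel bodies; the real functions delegate to the `_FT`/`_FP` kernels and to the feature's own method:

```python
import PIL.Image
import torch
from torchvision.prototype import features


def invert_stub(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
    if isinstance(inpt, torch.Tensor) and (
        torch.jit.is_scripting() or not isinstance(inpt, features._Feature)
    ):
        # Plain tensors, and every tensor while scripting, take the tensor-kernel
        # path; TorchScript only ever exercises this branch.
        return 255 - inpt  # stand-in for the real tensor kernel
    elif isinstance(inpt, features._Feature):
        # Feature subclasses dispatch to their own method, which re-wraps
        # the kernel output via new_like().
        return inpt.invert()
    else:
        # Anything else is assumed to be a PIL image.
        return PIL.Image.eval(inpt, lambda px: 255 - px)  # stand-in for the PIL kernel


print(invert_stub(torch.zeros(3, 2, 2, dtype=torch.uint8))[0, 0, 0])  # tensor(255, ...)
```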
@@ -8,11 +8,6 @@ from torchvision.prototype import features
 from torchvision.transforms import functional as _F
 
-# Due to torch.jit.script limitation we keep LegacyImageType as torch.Tensor
-# instead of Union[torch.Tensor, PIL.Image.Image]
-LegacyImageType = torch.Tensor
-
 
 @torch.jit.unused
 def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
     call = ", num_output_channels=3" if num_output_channels == 3 else ""
@@ -27,7 +22,7 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
     return _F.to_grayscale(inpt, num_output_channels=num_output_channels)
 
 
-def rgb_to_grayscale(inpt: LegacyImageType, num_output_channels: int = 1) -> LegacyImageType:
+def rgb_to_grayscale(inpt: features.LegacyImageTypeJIT, num_output_channels: int = 1) -> features.LegacyImageTypeJIT:
     old_color_space = (
         features._image._from_tensor_shape(inpt.shape)  # type: ignore[arg-type]
         if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image))
@@ -61,7 +56,7 @@ def to_tensor(inpt: Any) -> torch.Tensor:
     return _F.to_tensor(inpt)
 
 
-def get_image_size(inpt: features.ImageType) -> List[int]:
+def get_image_size(inpt: features.ImageTypeJIT) -> List[int]:
     warnings.warn(
         "The function `get_image_size(...)` is deprecated and will be removed in a future release. "
         "Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`."
...
@@ -18,7 +18,6 @@ from torchvision.transforms.functional_tensor import _parse_pad_padding
 from ._meta import convert_format_bounding_box, get_dimensions_image_pil, get_dimensions_image_tensor

 horizontal_flip_image_tensor = _FT.hflip
 horizontal_flip_image_pil = _FP.hflip
@@ -43,7 +42,7 @@ def horizontal_flip_bounding_box(
     ).view(shape)

-def horizontal_flip(inpt: features.DType) -> features.DType:
+def horizontal_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return horizontal_flip_image_tensor(inpt)
     elif isinstance(inpt, features._Feature):
@@ -76,7 +75,7 @@ def vertical_flip_bounding_box(
     ).view(shape)

-def vertical_flip(inpt: features.DType) -> features.DType:
+def vertical_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return vertical_flip_image_tensor(inpt)
     elif isinstance(inpt, features._Feature):
@@ -153,12 +152,12 @@ def resize_bounding_box(
 def resize(
-    inpt: features.DType,
+    inpt: features.InputTypeJIT,
     size: List[int],
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     max_size: Optional[int] = None,
     antialias: Optional[bool] = None,
-) -> features.DType:
+) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         antialias = False if antialias is None else antialias
         return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
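Worth noting in the hunk above: on the plain-tensor path, `antialias=None` is coerced to `False` before the kernel call, so the default preserves the historical tensor behavior. A small usage sketch (the module path is assumed):

    import torch
    from torchvision.prototype.transforms import functional as F  # module path assumed

    img = torch.rand(3, 256, 256)
    out_default = F.resize(img, [224, 224])                    # antialias=None -> False on this path
    out_explicit = F.resize(img, [224, 224], antialias=False)
    assert torch.equal(out_default, out_explicit)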
@@ -228,7 +227,7 @@ def affine_image_tensor(
     scale: float,
     shear: List[float],
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
-    fill: Optional[Union[int, float, List[float]]] = None,
+    fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     if img.numel() == 0:
@@ -260,7 +259,7 @@ def affine_image_pil(
     scale: float,
     shear: List[float],
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
-    fill: Optional[Union[int, float, List[float]]] = None,
+    fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> PIL.Image.Image:
     angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
@@ -378,7 +377,7 @@ def affine_mask(
     translate: List[float],
     scale: float,
     shear: List[float],
-    fill: Optional[Union[int, float, List[float]]] = None,
+    fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     if mask.ndim < 3:
@@ -404,9 +403,7 @@ def affine_mask(
     return output

-def _convert_fill_arg(
-    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]
-) -> Optional[Union[int, float, List[float]]]:
+def _convert_fill_arg(fill: features.FillType) -> features.FillTypeJIT:
     # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
     # So, we can't reassign fill to 0
     # if fill is None:
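The hunk above collapses the signature onto the shared `FillType`/`FillTypeJIT` aliases; the body after the leading comments is elided, so the normalization below is a hedged reconstruction of what torchscript compatibility requires, not the verbatim implementation:

    from typing import List, Sequence, Union

    FillType = Union[int, float, Sequence[int], Sequence[float], None]
    FillTypeJIT = Union[int, float, List[float], None]

    def _convert_fill_arg(fill: FillType) -> FillTypeJIT:
        # Fill = 0 is not equivalent to None (https://github.com/pytorch/vision/issues/6517),
        # so None must be passed through untouched rather than defaulted to 0.
        if fill is None:
            return fill
        # torchscript only understands List[float] among the sequence types,
        # so any Sequence[int] / Sequence[float] is converted up front.
        if not isinstance(fill, (int, float)):
            fill = [float(v) for v in list(fill)]
        return fill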
@@ -421,15 +418,15 @@ def _convert_fill_arg(
 def affine(
-    inpt: features.DType,
+    inpt: features.InputTypeJIT,
     angle: float,
     translate: List[float],
     scale: float,
     shear: List[float],
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
-    fill: Optional[Union[int, float, List[float]]] = None,
+    fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
-) -> features.DType:
+) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return affine_image_tensor(
             inpt,
@@ -463,7 +460,7 @@ def rotate_image_tensor(
     angle: float,
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
     expand: bool = False,
-    fill: Optional[Union[int, float, List[float]]] = None,
+    fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     num_channels, height, width = img.shape[-3:]
@@ -502,7 +499,7 @@ def rotate_image_pil(
     angle: float,
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
     expand: bool = False,
-    fill: Optional[Union[int, float, List[float]]] = None,
+    fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> PIL.Image.Image:
     if center is not None and expand:
@@ -542,7 +539,7 @@ def rotate_mask(
     mask: torch.Tensor,
     angle: float,
     expand: bool = False,
-    fill: Optional[Union[int, float, List[float]]] = None,
+    fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     if mask.ndim < 3:
@@ -567,13 +564,13 @@ def rotate_mask(
 def rotate(
-    inpt: features.DType,
+    inpt: features.InputTypeJIT,
     angle: float,
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
     expand: bool = False,
-    fill: Optional[Union[int, float, List[float]]] = None,
+    fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
-) -> features.DType:
+) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
     elif isinstance(inpt, features._Feature):
@@ -588,7 +585,7 @@ pad_image_pil = _FP.pad
 def pad_image_tensor(
     img: torch.Tensor,
     padding: Union[int, List[int]],
-    fill: Optional[Union[int, float, List[float]]] = None,
+    fill: features.FillTypeJIT = None,
     padding_mode: str = "constant",
 ) -> torch.Tensor:
     if fill is None:
@@ -652,7 +649,7 @@ def pad_mask(
     mask: torch.Tensor,
     padding: Union[int, List[int]],
     padding_mode: str = "constant",
-    fill: Optional[Union[int, float, List[float]]] = None,
+    fill: features.FillTypeJIT = None,
 ) -> torch.Tensor:
     if fill is None:
         fill = 0
@@ -698,11 +695,11 @@ def pad_bounding_box(
 def pad(
-    inpt: features.DType,
+    inpt: features.InputTypeJIT,
     padding: Union[int, List[int]],
-    fill: Optional[Union[int, float, List[float]]] = None,
+    fill: features.FillTypeJIT = None,
     padding_mode: str = "constant",
-) -> features.DType:
+) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
@@ -739,7 +736,7 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
     return crop_image_tensor(mask, top, left, height, width)

-def crop(inpt: features.DType, top: int, left: int, height: int, width: int) -> features.DType:
+def crop(inpt: features.InputTypeJIT, top: int, left: int, height: int, width: int) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return crop_image_tensor(inpt, top, left, height, width)
     elif isinstance(inpt, features._Feature):
@@ -752,7 +749,7 @@ def perspective_image_tensor(
     img: torch.Tensor,
     perspective_coeffs: List[float],
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-    fill: Optional[Union[int, float, List[float]]] = None,
+    fill: features.FillTypeJIT = None,
 ) -> torch.Tensor:
     return _FT.perspective(img, perspective_coeffs, interpolation=interpolation.value, fill=fill)
@@ -762,7 +759,7 @@ def perspective_image_pil(
     img: PIL.Image.Image,
     perspective_coeffs: List[float],
     interpolation: InterpolationMode = InterpolationMode.BICUBIC,
-    fill: Optional[Union[int, float, List[float]]] = None,
+    fill: features.FillTypeJIT = None,
 ) -> PIL.Image.Image:
     return _FP.perspective(img, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)
@@ -855,7 +852,7 @@ def perspective_bounding_box(
 def perspective_mask(
     mask: torch.Tensor,
     perspective_coeffs: List[float],
-    fill: Optional[Union[int, float, List[float]]] = None,
+    fill: features.FillTypeJIT = None,
 ) -> torch.Tensor:
     if mask.ndim < 3:
         mask = mask.unsqueeze(0)
@@ -874,11 +871,11 @@ def perspective_mask(
 def perspective(
-    inpt: features.DType,
+    inpt: features.InputTypeJIT,
     perspective_coeffs: List[float],
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-    fill: Optional[Union[int, float, List[float]]] = None,
+    fill: features.FillTypeJIT = None,
-) -> features.DType:
+) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
     elif isinstance(inpt, features._Feature):
@@ -891,7 +888,7 @@ def elastic_image_tensor(
     img: torch.Tensor,
     displacement: torch.Tensor,
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-    fill: Optional[Union[int, float, List[float]]] = None,
+    fill: features.FillTypeJIT = None,
 ) -> torch.Tensor:
     return _FT.elastic_transform(img, displacement, interpolation=interpolation.value, fill=fill)
@@ -901,7 +898,7 @@ def elastic_image_pil(
     img: PIL.Image.Image,
     displacement: torch.Tensor,
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-    fill: Optional[Union[int, float, List[float]]] = None,
+    fill: features.FillTypeJIT = None,
 ) -> PIL.Image.Image:
     t_img = pil_to_tensor(img)
     output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill)
@@ -951,7 +948,7 @@ def elastic_bounding_box(
 def elastic_mask(
     mask: torch.Tensor,
     displacement: torch.Tensor,
-    fill: Optional[Union[int, float, List[float]]] = None,
+    fill: features.FillTypeJIT = None,
 ) -> torch.Tensor:
     if mask.ndim < 3:
         mask = mask.unsqueeze(0)
@@ -968,11 +965,11 @@ def elastic_mask(
 def elastic(
-    inpt: features.DType,
+    inpt: features.InputTypeJIT,
     displacement: torch.Tensor,
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-    fill: Optional[Union[int, float, List[float]]] = None,
+    fill: features.FillTypeJIT = None,
-) -> features.DType:
+) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
     elif isinstance(inpt, features._Feature):
@@ -1069,7 +1066,7 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     return output

-def center_crop(inpt: features.DType, output_size: List[int]) -> features.DType:
+def center_crop(inpt: features.InputTypeJIT, output_size: List[int]) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return center_crop_image_tensor(inpt, output_size)
     elif isinstance(inpt, features._Feature):
@@ -1132,7 +1129,7 @@ def resized_crop_mask(
 def resized_crop(
-    inpt: features.DType,
+    inpt: features.InputTypeJIT,
     top: int,
     left: int,
     height: int,
@@ -1140,7 +1137,7 @@ def resized_crop(
     size: List[int],
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     antialias: Optional[bool] = None,
-) -> features.DType:
+) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         antialias = False if antialias is None else antialias
         return resized_crop_image_tensor(
@@ -1205,9 +1202,11 @@ def five_crop_image_pil(
 def five_crop(
-    inpt: features.ImageType, size: List[int]
-) -> Tuple[features.ImageType, features.ImageType, features.ImageType, features.ImageType, features.ImageType]:
-    # TODO: consider breaking BC here to return List[features.ImageType] to align this op with `ten_crop`
+    inpt: features.ImageTypeJIT, size: List[int]
+) -> Tuple[
+    features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT
+]:
+    # TODO: consider breaking BC here to return List[features.ImageTypeJIT] to align this op with `ten_crop`
     if isinstance(inpt, torch.Tensor):
         output = five_crop_image_tensor(inpt, size)
         if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
@@ -1244,7 +1243,7 @@ def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: bool
     return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]

-def ten_crop(inpt: features.ImageType, size: List[int], vertical_flip: bool = False) -> List[features.ImageType]:
+def ten_crop(inpt: features.ImageTypeJIT, size: List[int], vertical_flip: bool = False) -> List[features.ImageTypeJIT]:
     if isinstance(inpt, torch.Tensor):
         output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
         if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
...
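The `five_crop`/`ten_crop` hunks above are the one place where the return types deliberately differ: a fixed 5-tuple versus a list, which is exactly the mismatch the TODO proposes to reconcile. A usage sketch (the module path is assumed):

    import torch
    from torchvision.prototype.transforms import functional as F  # module path assumed

    img = torch.rand(3, 256, 256)

    # five_crop returns a fixed 5-tuple of crops...
    tl, tr, bl, br, center = F.five_crop(img, [224, 224])

    # ...while ten_crop returns a List, the shape the TODO wants to align on.
    crops = F.ten_crop(img, [224, 224], vertical_flip=False)
    assert len(crops) == 10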
@@ -11,7 +11,7 @@ get_dimensions_image_pil = _FP.get_dimensions
 # TODO: Should this be prefixed with `_` similar to other methods that don't get exposed by init?
-def get_chw(image: features.ImageType) -> Tuple[int, int, int]:
+def get_chw(image: features.ImageTypeJIT) -> Tuple[int, int, int]:
     if isinstance(image, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(image, features.Image)):
         channels, height, width = get_dimensions_image_tensor(image)
     elif isinstance(image, features.Image):
@@ -29,11 +29,11 @@ def get_chw(image: features.ImageType) -> Tuple[int, int, int]:
 # detailed above.

-def get_dimensions(image: features.ImageType) -> List[int]:
+def get_dimensions(image: features.ImageTypeJIT) -> List[int]:
     return list(get_chw(image))

-def get_num_channels(image: features.ImageType) -> int:
+def get_num_channels(image: features.ImageTypeJIT) -> int:
     num_channels, *_ = get_chw(image)
     return num_channels
@@ -43,7 +43,7 @@ def get_num_channels(image: features.ImageType) -> int:
 get_image_num_channels = get_num_channels

-def get_spatial_size(image: features.ImageType) -> List[int]:
+def get_spatial_size(image: features.ImageTypeJIT) -> List[int]:
     _, *size = get_chw(image)
     return size
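Since `get_chw` yields `(C, H, W)`, the helpers above just slice that tuple: `get_spatial_size` returns `[h, w]`, which is the ordering change the `get_image_size` deprecation message earlier in this diff points to (the old function returned `[w, h]`). A quick sketch (the module path is assumed):

    import torch
    from torchvision.prototype.transforms import functional as F  # module path assumed

    img = torch.rand(3, 480, 640)  # (C, H, W)
    assert F.get_dimensions(img) == [3, 480, 640]
    assert F.get_spatial_size(img) == [480, 640]  # [h, w], unlike the deprecated [w, h]
    assert F.get_num_channels(img) == 3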
@@ -208,11 +208,11 @@ def convert_color_space_image_pil(
 def convert_color_space(
-    inpt: features.ImageType,
+    inpt: features.ImageTypeJIT,
     color_space: ColorSpace,
     old_color_space: Optional[ColorSpace] = None,
     copy: bool = True,
-) -> features.ImageType:
+) -> features.ImageTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image)):
         if old_color_space is None:
             raise RuntimeError(
@@ -225,4 +225,4 @@ def convert_color_space(
     elif isinstance(inpt, features.Image):
         return inpt.to_color_space(color_space, copy=copy)
     else:
-        return cast(features.ImageType, convert_color_space_image_pil(inpt, color_space, copy=copy))
+        return cast(features.ImageTypeJIT, convert_color_space_image_pil(inpt, color_space, copy=copy))
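The tensor branch above raises when `old_color_space` is missing because a plain tensor carries no color-space metadata, while a `features.Image` knows its own. A hedged usage sketch; the module path, the `ColorSpace` member names, and the `Image(..., color_space=...)` constructor argument are assumed from the prototype API rather than shown in this diff:

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F  # module path assumed

    plain = torch.rand(3, 32, 32)
    # Plain tensor: the caller must say what color space it is currently in.
    gray = F.convert_color_space(
        plain, color_space=features.ColorSpace.GRAY, old_color_space=features.ColorSpace.RGB
    )

    # features.Image: metadata travels with the tensor, so no old_color_space.
    img = features.Image(plain, color_space=features.ColorSpace.RGB)
    gray_img = F.convert_color_space(img, color_space=features.ColorSpace.GRAY)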
@@ -6,16 +6,12 @@ from torchvision.prototype import features
 from torchvision.transforms import functional_tensor as _FT
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image

-# Due to torch.jit.script limitation we keep TensorImageType as torch.Tensor
-# instead of Union[torch.Tensor, features.Image]
-TensorImageType = torch.Tensor

 normalize_image_tensor = _FT.normalize

-def normalize(inpt: TensorImageType, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
+def normalize(
+    inpt: features.TensorImageTypeJIT, mean: List[float], std: List[float], inplace: bool = False
+) -> torch.Tensor:
     if not isinstance(inpt, torch.Tensor):
         raise TypeError(f"img should be Tensor Image. Got {type(inpt)}")
     else:
@@ -62,7 +58,9 @@ def gaussian_blur_image_pil(
     return to_pil_image(output, mode=img.mode)

-def gaussian_blur(inpt: features.DType, kernel_size: List[int], sigma: Optional[List[float]] = None) -> features.DType:
+def gaussian_blur(
+    inpt: features.InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
+) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
     elif isinstance(inpt, features._Feature):
...
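The payoff of the `*JIT` aliases used throughout this commit: the public ops keep Union-flavored eager semantics while remaining compilable, since under `torch.jit.script` the annotations collapse to `torch.Tensor`. A sketch, assuming `gaussian_blur` scripts cleanly, which is what these annotation changes are after rather than something this diff demonstrates:

    import torch
    from torchvision.prototype.transforms import functional as F  # module path assumed

    # Scripting succeeds because InputTypeJIT is torch.Tensor; the scripted
    # variant therefore only accepts tensor inputs, never PIL images or Features.
    scripted_blur = torch.jit.script(F.gaussian_blur)

    img = torch.rand(3, 64, 64)
    out = scripted_blur(img, kernel_size=[5, 5], sigma=[1.5, 1.5])
    assert out.shape == img.shape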