Unverified Commit 841b9a19 authored by Vasilis Vryniotis, committed by GitHub

Make prototype `F` JIT-scriptable (#6584)



* Improve existing low-level kernel test.

* Add new mid-level JIT-scriptability test (failing).

* Remove duplicate aliases from kernel tests.

* Fix color kernels.

* Fix deprecated kernels.

* Fix mypy errors.

* Silence mypy instead of fixing, to avoid a performance penalty.

* Fix augment kernels.

* Fix meta kernels.

* Remove `is_tracing` calls.

* Add fake `ImageType` and `DType`.

* Fix type-conversion kernels.

* Fix misc kernels.

* Partially fix geometry kernels.

* Remove mutable default from `_pad_with_vector_fill()` and all other unnecessary defaults.

* Fix geometry ops.

* Fix tests.

* Remove xfail for JIT tests on mid-level ops.
Co-authored-by: vfdev-5 <vfdev.5@gmail.com>
parent 3a1f05ed
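The recurring change across the diff below is a single dispatch pattern: each mid-level op annotates its input as a plain `torch.Tensor` (via the fake `DType`/`ImageType` aliases) and routes to the pure-tensor kernel under scripting, while PIL-only kernels are marked `@torch.jit.unused`. A condensed, self-contained sketch of the mechanism; all names here are illustrative stand-ins, not the actual torchvision kernels:

```python
import torch


def flip_image_tensor(img: torch.Tensor) -> torch.Tensor:
    # The pure-tensor kernel: the only branch TorchScript has to compile.
    return img.flip(-1)


@torch.jit.unused
def flip_image_fancy(img: torch.Tensor) -> torch.Tensor:
    # Stand-in for the PIL/_Feature handling that TorchScript cannot compile;
    # @torch.jit.unused replaces the body with a raising stub when scripted.
    return img.flip(-1)


def flip(inpt: torch.Tensor) -> torch.Tensor:
    # `inpt` is annotated as a plain Tensor (the role of the fake aliases),
    # so the scripted signature type-checks even for richer eager inputs.
    if torch.jit.is_scripting():
        return flip_image_tensor(inpt)
    return flip_image_fancy(inpt)


scripted = torch.jit.script(flip)
x = torch.rand(3, 4, 5)
assert torch.equal(scripted(x), flip_image_tensor(x))
```

In the real diff the eager path additionally dispatches on `features._Feature` and `PIL.Image.Image`, as seen throughout the files below.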
@@ -113,14 +113,21 @@ DISPATCHER_INFOS = [
             features.Mask: F.pad_mask,
         },
     ),
-    DispatcherInfo(
-        F.perspective,
-        kernels={
-            features.Image: F.perspective_image_tensor,
-            features.BoundingBox: F.perspective_bounding_box,
-            features.Mask: F.perspective_mask,
-        },
-    ),
+    # FIXME:
+    # RuntimeError: perspective() is missing value for argument 'startpoints'.
+    # Declaration: perspective(Tensor inpt, int[][] startpoints, int[][] endpoints,
+    #     Enum<__torch__.torchvision.transforms.functional.InterpolationMode> interpolation=Enum<InterpolationMode.BILINEAR>,
+    #     Union(float[], float, int, NoneType) fill=None) -> Tensor
+    #
+    # This is probably due to the fact that F.perspective does not have the same signature as F.perspective_image_tensor
+    # DispatcherInfo(
+    #     F.perspective,
+    #     kernels={
+    #         features.Image: F.perspective_image_tensor,
+    #         features.BoundingBox: F.perspective_bounding_box,
+    #         features.Mask: F.perspective_mask,
+    #     },
+    # ),
     DispatcherInfo(
         F.center_crop,
         kernels={
...
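The FIXME above records why the `F.perspective` entry is commented out: the dispatcher takes `startpoints`/`endpoints`, while the kernels it wraps take precomputed `perspective_coeffs`, so the shared dispatcher test calls the scripted op with kernel-style arguments. A hypothetical repro of the reported failure (the argument values are made up):

```python
import torch
from torchvision.prototype.transforms import functional as F

scripted = torch.jit.script(F.perspective)
coeffs = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
# Kernel-style call: the scripted dispatcher wants startpoints/endpoints, so
# this raises "RuntimeError: perspective() is missing value for argument
# 'startpoints'." as quoted in the FIXME above.
scripted(torch.rand(3, 32, 32), perspective_coeffs=coeffs)
```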
@@ -376,6 +376,9 @@ class TestPad:
         inpt = mocker.MagicMock(spec=features.Image)
         _ = transform(inpt)
 
+        fill = transforms.functional._geometry._convert_fill_arg(fill)
+        if isinstance(padding, tuple):
+            padding = list(padding)
         fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
 
     @pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}])
@@ -389,14 +392,17 @@ class TestPad:
         _ = transform(inpt)
 
         if isinstance(fill, int):
+            fill = transforms.functional._geometry._convert_fill_arg(fill)
             calls = [
                 mocker.call(image, padding=1, fill=fill, padding_mode="constant"),
                 mocker.call(mask, padding=1, fill=fill, padding_mode="constant"),
             ]
         else:
+            fill_img = transforms.functional._geometry._convert_fill_arg(fill[type(image)])
+            fill_mask = transforms.functional._geometry._convert_fill_arg(fill[type(mask)])
             calls = [
-                mocker.call(image, padding=1, fill=fill[type(image)], padding_mode="constant"),
-                mocker.call(mask, padding=1, fill=fill[type(mask)], padding_mode="constant"),
+                mocker.call(image, padding=1, fill=fill_img, padding_mode="constant"),
+                mocker.call(mask, padding=1, fill=fill_mask, padding_mode="constant"),
             ]
         fn.assert_has_calls(calls)
@@ -447,6 +453,7 @@ class TestRandomZoomOut:
         torch.rand(1)  # random apply changes random state
         params = transform._get_params(inpt)
 
+        fill = transforms.functional._geometry._convert_fill_arg(fill)
         fn.assert_called_once_with(inpt, **params, fill=fill)
 
     @pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}])
@@ -465,14 +472,17 @@ class TestRandomZoomOut:
         params = transform._get_params(inpt)
 
         if isinstance(fill, int):
+            fill = transforms.functional._geometry._convert_fill_arg(fill)
             calls = [
                 mocker.call(image, **params, fill=fill),
                 mocker.call(mask, **params, fill=fill),
             ]
         else:
+            fill_img = transforms.functional._geometry._convert_fill_arg(fill[type(image)])
+            fill_mask = transforms.functional._geometry._convert_fill_arg(fill[type(mask)])
             calls = [
-                mocker.call(image, **params, fill=fill[type(image)]),
-                mocker.call(mask, **params, fill=fill[type(mask)]),
+                mocker.call(image, **params, fill=fill_img),
+                mocker.call(mask, **params, fill=fill_mask),
             ]
         fn.assert_has_calls(calls)
@@ -533,6 +543,7 @@ class TestRandomRotation:
         torch.manual_seed(12)
         params = transform._get_params(inpt)
 
+        fill = transforms.functional._geometry._convert_fill_arg(fill)
         fn.assert_called_once_with(inpt, **params, interpolation=interpolation, expand=expand, fill=fill, center=center)
 
     @pytest.mark.parametrize("angle", [34, -87])
@@ -670,6 +681,7 @@ class TestRandomAffine:
         torch.manual_seed(12)
         params = transform._get_params(inpt)
 
+        fill = transforms.functional._geometry._convert_fill_arg(fill)
         fn.assert_called_once_with(inpt, **params, interpolation=interpolation, fill=fill, center=center)
@@ -917,6 +929,7 @@ class TestRandomPerspective:
         torch.rand(1)  # random apply changes random state
         params = transform._get_params(inpt)
 
+        fill = transforms.functional._geometry._convert_fill_arg(fill)
         fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
@@ -986,6 +999,7 @@ class TestElasticTransform:
         transform._get_params = mocker.MagicMock()
         _ = transform(inpt)
         params = transform._get_params(inpt)
+        fill = transforms.functional._geometry._convert_fill_arg(fill)
         fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
@@ -1609,6 +1623,7 @@ class TestFixedSizeCrop:
         if not needs_crop:
             assert args[0] is inpt_sentinel
             assert args[1] is padding_sentinel
+            fill_sentinel = transforms.functional._geometry._convert_fill_arg(fill_sentinel)
             assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel)
         else:
             mock_pad.assert_not_called()
...
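All of the test updates above pre-convert `fill` through `transforms.functional._geometry._convert_fill_arg`, mirroring what the transforms now do before dispatching. The helper's body is not shown in this diff; the sketch below is a hypothetical re-implementation inferred from the signatures that appear later (scalars pass through, sequences become `List[float]`, `None` stays `None`):

```python
from typing import List, Optional, Sequence, Union


def convert_fill_arg(
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]
) -> Optional[Union[int, float, List[float]]]:
    # Hypothetical stand-in for _convert_fill_arg; the real helper lives in
    # torchvision.prototype.transforms.functional._geometry.
    if fill is None:
        return None
    if isinstance(fill, (int, float)):
        return fill
    # Sequence[int] / Sequence[float] -> List[float], the only sequence type
    # the JIT-scriptable kernel signatures in this diff accept.
    return [float(v) for v in fill]
```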
@@ -9,7 +9,6 @@ from torchvision.prototype import features
 
 class TestCommon:
-    @pytest.mark.xfail(reason="dispatchers are currently not scriptable")
     @pytest.mark.parametrize(
         ("info", "args_kwargs"),
         [
...
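With the xfail gone, scripting a dispatcher is expected to just work end to end. A minimal smoke check (assumes a torchvision build containing this commit):

```python
import torch
from torchvision.prototype.transforms import functional as F

scripted = torch.jit.script(F.resize)
out = scripted(torch.rand(3, 32, 32), size=[16, 16])
assert out.shape[-2:] == (16, 16)
```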
@@ -407,27 +407,74 @@ def erase_image_tensor():
         yield ArgsKwargs(image, i=1, j=2, h=6, w=7, v=torch.rand(c, 6, 7))
 
 
+_KERNEL_TYPES = {"_image_tensor", "_image_pil", "_mask", "_bounding_box", "_label"}
+
+
+def _distinct_callables(callable_names):
+    # Ensure we deduplicate callables (due to aliases) without losing the names on the new API
+    remove = set()
+    distinct = set()
+    for name in callable_names:
+        item = F.__dict__[name]
+        if item not in distinct:
+            distinct.add(item)
+        else:
+            remove.add(name)
+    callable_names -= remove
+
+    # create tuple and sort by name
+    return sorted([(name, F.__dict__[name]) for name in callable_names], key=lambda t: t[0])
+
+
+def _get_distinct_kernels():
+    kernel_names = {
+        name
+        for name, f in F.__dict__.items()
+        if callable(f) and not name.startswith("_") and any(name.endswith(k) for k in _KERNEL_TYPES)
+    }
+    return _distinct_callables(kernel_names)
+
+
+def _get_distinct_midlevels():
+    midlevel_names = {
+        name
+        for name, f in F.__dict__.items()
+        if callable(f) and not name.startswith("_") and not any(name.endswith(k) for k in _KERNEL_TYPES)
+    }
+    return _distinct_callables(midlevel_names)
+
+
 @pytest.mark.parametrize(
     "kernel",
     [
         pytest.param(kernel, id=name)
-        for name, kernel in F.__dict__.items()
-        if not name.startswith("_")
-        and callable(kernel)
-        and any(feature_type in name for feature_type in {"image", "mask", "bounding_box", "label"})
-        and "pil" not in name
-        and name
-        not in {
-            "to_image_tensor",
-            "get_num_channels",
-            "get_spatial_size",
-            "get_image_num_channels",
-            "get_image_size",
-        }
+        for name, kernel in _get_distinct_kernels()
+        if not name.endswith("_image_pil") and name not in {"to_image_tensor"}
     ],
 )
-def test_scriptable(kernel):
-    jit.script(kernel)
+def test_scriptable_kernel(kernel):
+    jit.script(kernel)  # TODO: pass data through it
+
+
+@pytest.mark.parametrize(
+    "midlevel",
+    [
+        pytest.param(midlevel, id=name)
+        for name, midlevel in _get_distinct_midlevels()
+        if name
+        not in {
+            "InterpolationMode",
+            "decode_image_with_pil",
+            "decode_video_with_av",
+            "pil_to_tensor",
+            "to_grayscale",
+            "to_pil_image",
+            "to_tensor",
+        }
+    ],
+)
+def test_scriptable_midlevel(midlevel):
+    jit.script(midlevel)  # TODO: pass data through it
 
 
 # Test below is intended to test mid-level op vs low-level ops it calls
@@ -439,8 +486,8 @@ def test_scriptable(kernel):
     [
         pytest.param(func, id=name)
         for name, func in F.__dict__.items()
-        if not name.startswith("_")
-        and callable(func)
+        if not name.startswith("_") and callable(func)
+        # TODO: remove aliases
         and all(feature_type not in name for feature_type in {"image", "mask", "bounding_box", "label", "pil"})
         and name
         not in {
...
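The dedup helpers above exist because `F` keeps pure aliases (the geometry file below notes that `hflip`/`vflip` alias `horizontal_flip`/`vertical_flip`, and `elastic_transform = elastic`), so iterating `F.__dict__` naively would script and report the same callable twice. Roughly, using the test helper (assumes this commit's namespace):

```python
from torchvision.prototype.transforms import functional as F

assert F.hflip is F.horizontal_flip  # pure alias: one underlying callable
pairs = _distinct_callables({"hflip", "horizontal_flip"})
assert len(pairs) == 1  # only one (name, callable) entry survives
```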
 from ._bounding_box import BoundingBox, BoundingBoxFormat
 from ._encoded import EncodedData, EncodedImage, EncodedVideo
-from ._feature import _Feature, is_simple_tensor
-from ._image import ColorSpace, Image
+from ._feature import _Feature, DType, is_simple_tensor
+from ._image import ColorSpace, Image, ImageType
 from ._label import Label, OneHotLabel
 from ._mask import Mask
@@ -10,6 +10,11 @@ from torchvision.transforms import InterpolationMode
 F = TypeVar("F", bound="_Feature")
 
+# Due to torch.jit.script limitation we keep DType as torch.Tensor
+# instead of Union[torch.Tensor, PIL.Image.Image, features._Feature]
+DType = torch.Tensor
+
+
 def is_simple_tensor(inpt: Any) -> bool:
     return isinstance(inpt, torch.Tensor) and not isinstance(inpt, _Feature)
...
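TorchScript cannot represent `PIL.Image.Image` (or arbitrary tensor subclasses) inside a `Union` annotation, so the alias deliberately under-promises: it names the union but is just `torch.Tensor`. A self-contained sketch of why this scripts cleanly while eager callers can still pass richer types (illustrative names):

```python
import torch

DType = torch.Tensor  # stands in for Union[Tensor, PIL.Image.Image, _Feature]


def identity(inpt: DType) -> DType:
    return inpt


scripted = torch.jit.script(identity)  # fine: the signature is Tensor -> Tensor
scripted(torch.rand(2, 2))
# In eager mode, Tensor subclasses (e.g. features.Image) also pass through,
# since Python does not enforce the annotation.
```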
@@ -12,6 +12,11 @@ from ._bounding_box import BoundingBox
 from ._feature import _Feature
 
+# Due to torch.jit.script limitation we keep ImageType as torch.Tensor
+# instead of Union[torch.Tensor, PIL.Image.Image, features.Image]
+ImageType = torch.Tensor
+
+
 class ColorSpace(StrEnum):
     OTHER = StrEnum.auto()
     GRAY = StrEnum.auto()
@@ -32,6 +37,31 @@ class ColorSpace(StrEnum):
         else:
             return cls.OTHER
 
+    @staticmethod
+    def from_tensor_shape(shape: List[int]) -> ColorSpace:
+        return _from_tensor_shape(shape)
+
+
+def _from_tensor_shape(shape: List[int]) -> ColorSpace:
+    # Needed as a standalone method for JIT
+    ndim = len(shape)
+    if ndim < 2:
+        return ColorSpace.OTHER
+    elif ndim == 2:
+        return ColorSpace.GRAY
+
+    num_channels = shape[-3]
+    if num_channels == 1:
+        return ColorSpace.GRAY
+    elif num_channels == 2:
+        return ColorSpace.GRAY_ALPHA
+    elif num_channels == 3:
+        return ColorSpace.RGB
+    elif num_channels == 4:
+        return ColorSpace.RGB_ALPHA
+    else:
+        return ColorSpace.OTHER
+
 
 class Image(_Feature):
     color_space: ColorSpace
@@ -53,7 +83,7 @@ class Image(_Feature):
         image = super().__new__(cls, data, requires_grad=requires_grad)
 
         if color_space is None:
-            color_space = cls.guess_color_space(image)
+            color_space = ColorSpace.from_tensor_shape(image.shape)  # type: ignore[arg-type]
             if color_space == ColorSpace.OTHER:
                 warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
         elif isinstance(color_space, str):
@@ -83,25 +113,6 @@ class Image(_Feature):
     def num_channels(self) -> int:
         return self.shape[-3]
 
-    @staticmethod
-    def guess_color_space(data: torch.Tensor) -> ColorSpace:
-        if data.ndim < 2:
-            return ColorSpace.OTHER
-        elif data.ndim == 2:
-            return ColorSpace.GRAY
-
-        num_channels = data.shape[-3]
-        if num_channels == 1:
-            return ColorSpace.GRAY
-        elif num_channels == 2:
-            return ColorSpace.GRAY_ALPHA
-        elif num_channels == 3:
-            return ColorSpace.RGB
-        elif num_channels == 4:
-            return ColorSpace.RGB_ALPHA
-        else:
-            return ColorSpace.OTHER
-
     def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) -> Image:
         if isinstance(color_space, str):
             color_space = ColorSpace.from_str(color_space.upper())
...
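The heuristic itself is unchanged by the move; the expected mapping, per the shape checks above (assumes this commit):

```python
from torchvision.prototype.features import ColorSpace

assert ColorSpace.from_tensor_shape([3, 32, 32]) == ColorSpace.RGB
assert ColorSpace.from_tensor_shape([1, 32, 32]) == ColorSpace.GRAY
assert ColorSpace.from_tensor_shape([2, 32, 32]) == ColorSpace.GRAY_ALPHA
assert ColorSpace.from_tensor_shape([4, 32, 32]) == ColorSpace.RGB_ALPHA
assert ColorSpace.from_tensor_shape([32, 32]) == ColorSpace.GRAY  # 2-D: no channel dim
assert ColorSpace.from_tensor_shape([5, 32, 32]) == ColorSpace.OTHER
```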
@@ -72,11 +72,10 @@ class _AutoAugmentBase(Transform):
         # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
         # So, we have to put fill as None if fill == 0
-        fill_: Optional[Union[int, float, Sequence[int], Sequence[float]]]
+        # This is due to BC with stable API which has fill = None by default
+        fill_ = F._geometry._convert_fill_arg(fill)
         if isinstance(fill, int) and fill == 0:
             fill_ = None
-        else:
-            fill_ = fill
 
         if transform_id == "Identity":
             return image
...
@@ -252,7 +252,14 @@ class Pad(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = self.fill[type(inpt)]
-        return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)
+
+        # This cast does Sequence[int] -> List[int] and is required to make mypy happy
+        padding = self.padding
+        if not isinstance(padding, int):
+            padding = list(padding)
+
+        fill = F._geometry._convert_fill_arg(fill)
+        return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
 
 
 class RandomZoomOut(_RandomApplyTransform):
@@ -290,6 +297,7 @@ class RandomZoomOut(_RandomApplyTransform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = self.fill[type(inpt)]
+        fill = F._geometry._convert_fill_arg(fill)
         return F.pad(inpt, **params, fill=fill)
@@ -320,6 +328,7 @@ class RandomRotation(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = self.fill[type(inpt)]
+        fill = F._geometry._convert_fill_arg(fill)
         return F.rotate(
             inpt,
             **params,
@@ -401,6 +410,7 @@ class RandomAffine(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = self.fill[type(inpt)]
+        fill = F._geometry._convert_fill_arg(fill)
         return F.affine(
             inpt,
             **params,
@@ -480,8 +490,15 @@ class RandomCrop(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         # TODO: (PERF) check for speed optimization if we avoid repeated pad calls
         fill = self.fill[type(inpt)]
+        fill = F._geometry._convert_fill_arg(fill)
+
         if self.padding is not None:
-            inpt = F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)
+            # This cast does Sequence[int] -> List[int] and is required to make mypy happy
+            padding = self.padding
+            if not isinstance(padding, int):
+                padding = list(padding)
+            inpt = F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
 
         if self.pad_if_needed:
             input_width, input_height = params["input_width"], params["input_height"]
@@ -543,6 +560,7 @@ class RandomPerspective(_RandomApplyTransform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = self.fill[type(inpt)]
+        fill = F._geometry._convert_fill_arg(fill)
         return F.perspective(
             inpt,
             **params,
@@ -610,6 +628,7 @@ class ElasticTransform(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = self.fill[type(inpt)]
+        fill = F._geometry._convert_fill_arg(fill)
         return F.elastic(
             inpt,
             **params,
@@ -868,6 +887,7 @@ class FixedSizeCrop(Transform):
         if params["needs_pad"]:
             fill = self.fill[type(inpt)]
+            fill = F._geometry._convert_fill_arg(fill)
             inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
 
         return inpt
...
-from typing import Any, Dict, Optional, Union
+from typing import Any, cast, Dict, Optional, Union
 
 import numpy as np
 import PIL.Image
@@ -13,7 +13,7 @@ class DecodeImage(Transform):
     _transformed_types = (features.EncodedImage,)
 
     def _transform(self, inpt: torch.Tensor, params: Dict[str, Any]) -> features.Image:
-        return F.decode_image_with_pil(inpt)
+        return cast(features.Image, F.decode_image_with_pil(inpt))
 
 
 class LabelToOneHot(Transform):
@@ -50,7 +50,7 @@ class ToImageTensor(Transform):
     def _transform(
         self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
     ) -> features.Image:
-        return F.to_image_tensor(inpt)
+        return cast(features.Image, F.to_image_tensor(inpt))
 
 
 class ToImagePIL(Transform):
...
@@ -74,7 +74,7 @@ from ._geometry import (
     five_crop,
     five_crop_image_pil,
     five_crop_image_tensor,
-    hflip,
+    hflip,  # TODO: Consider moving all pure alias definitions at the bottom of the file
     horizontal_flip,
     horizontal_flip_bounding_box,
     horizontal_flip_image_pil,
...
-from typing import Union
-
 import PIL.Image
 import torch
@@ -11,6 +9,7 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 erase_image_tensor = _FT.erase
 
 
+@torch.jit.unused
 def erase_image_pil(
     img: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
 ) -> PIL.Image.Image:
@@ -20,17 +19,17 @@ def erase_image_pil(
 def erase(
-    inpt: Union[torch.Tensor, PIL.Image.Image, features.Image],
+    inpt: features.ImageType,
     i: int,
     j: int,
     h: int,
     w: int,
     v: torch.Tensor,
     inplace: bool = False,
-) -> Union[torch.Tensor, PIL.Image.Image, features.Image]:
+) -> features.ImageType:
     if isinstance(inpt, torch.Tensor):
         output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
-        if isinstance(inpt, features.Image):
+        if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
             output = features.Image.new_like(inpt, output)
         return output
     else:  # isinstance(inpt, PIL.Image.Image):
...
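Every `*_image_pil` kernel in this PR gains `@torch.jit.unused`, and feature wrapping is guarded by `torch.jit.is_scripting()`, exactly as in `erase` above. A self-contained sketch of what the decorator buys (names are illustrative, not the real kernels):

```python
import PIL.Image
import torch


@torch.jit.unused
def brighten_image_pil(img: PIL.Image.Image) -> PIL.Image.Image:
    # PIL-only code TorchScript cannot compile; the decorator swaps the body
    # for a raising stub in the scripted graph.
    return img.point(lambda p: min(255, p + 25))


def brighten(inpt: torch.Tensor) -> torch.Tensor:
    if isinstance(inpt, torch.Tensor):
        return (inpt + 0.1).clamp(0.0, 1.0)
    else:  # PIL branch: eager-only, statically pruned/stubbed under script
        return brighten_image_pil(inpt)


torch.jit.script(brighten)  # compiles despite the unscriptable PIL kernel
```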
-from typing import Union
-
-import PIL.Image
 import torch
 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
 
-# shortcut type
-DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
-
 adjust_brightness_image_tensor = _FT.adjust_brightness
 adjust_brightness_image_pil = _FP.adjust_brightness
 
 
-def adjust_brightness(inpt: DType, brightness_factor: float) -> DType:
-    if isinstance(inpt, features._Feature):
+def adjust_brightness(inpt: features.DType, brightness_factor: float) -> features.DType:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
+    elif isinstance(inpt, features._Feature):
         return inpt.adjust_brightness(brightness_factor=brightness_factor)
-    elif isinstance(inpt, PIL.Image.Image):
-        return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
     else:
-        return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
+        return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
 
 
 adjust_saturation_image_tensor = _FT.adjust_saturation
 adjust_saturation_image_pil = _FP.adjust_saturation
 
 
-def adjust_saturation(inpt: DType, saturation_factor: float) -> DType:
-    if isinstance(inpt, features._Feature):
+def adjust_saturation(inpt: features.DType, saturation_factor: float) -> features.DType:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
+    elif isinstance(inpt, features._Feature):
         return inpt.adjust_saturation(saturation_factor=saturation_factor)
-    elif isinstance(inpt, PIL.Image.Image):
-        return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
     else:
-        return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
+        return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
 
 
 adjust_contrast_image_tensor = _FT.adjust_contrast
 adjust_contrast_image_pil = _FP.adjust_contrast
 
 
-def adjust_contrast(inpt: DType, contrast_factor: float) -> DType:
-    if isinstance(inpt, features._Feature):
+def adjust_contrast(inpt: features.DType, contrast_factor: float) -> features.DType:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
+    elif isinstance(inpt, features._Feature):
         return inpt.adjust_contrast(contrast_factor=contrast_factor)
-    elif isinstance(inpt, PIL.Image.Image):
-        return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
     else:
-        return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
+        return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
 
 
 adjust_sharpness_image_tensor = _FT.adjust_sharpness
 adjust_sharpness_image_pil = _FP.adjust_sharpness
 
 
-def adjust_sharpness(inpt: DType, sharpness_factor: float) -> DType:
-    if isinstance(inpt, features._Feature):
+def adjust_sharpness(inpt: features.DType, sharpness_factor: float) -> features.DType:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
+    elif isinstance(inpt, features._Feature):
         return inpt.adjust_sharpness(sharpness_factor=sharpness_factor)
-    elif isinstance(inpt, PIL.Image.Image):
-        return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
     else:
-        return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
+        return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
 
 
 adjust_hue_image_tensor = _FT.adjust_hue
 adjust_hue_image_pil = _FP.adjust_hue
 
 
-def adjust_hue(inpt: DType, hue_factor: float) -> DType:
-    if isinstance(inpt, features._Feature):
+def adjust_hue(inpt: features.DType, hue_factor: float) -> features.DType:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
+    elif isinstance(inpt, features._Feature):
         return inpt.adjust_hue(hue_factor=hue_factor)
-    elif isinstance(inpt, PIL.Image.Image):
-        return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
     else:
-        return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
+        return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
 
 
 adjust_gamma_image_tensor = _FT.adjust_gamma
 adjust_gamma_image_pil = _FP.adjust_gamma
 
 
-def adjust_gamma(inpt: DType, gamma: float, gain: float = 1) -> DType:
-    if isinstance(inpt, features._Feature):
+def adjust_gamma(inpt: features.DType, gamma: float, gain: float = 1) -> features.DType:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
+    elif isinstance(inpt, features._Feature):
         return inpt.adjust_gamma(gamma=gamma, gain=gain)
-    elif isinstance(inpt, PIL.Image.Image):
-        return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
     else:
-        return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
+        return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
 
 
 posterize_image_tensor = _FT.posterize
 posterize_image_pil = _FP.posterize
 
 
-def posterize(inpt: DType, bits: int) -> DType:
-    if isinstance(inpt, features._Feature):
+def posterize(inpt: features.DType, bits: int) -> features.DType:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return posterize_image_tensor(inpt, bits=bits)
+    elif isinstance(inpt, features._Feature):
         return inpt.posterize(bits=bits)
-    elif isinstance(inpt, PIL.Image.Image):
-        return posterize_image_pil(inpt, bits=bits)
     else:
-        return posterize_image_tensor(inpt, bits=bits)
+        return posterize_image_pil(inpt, bits=bits)
 
 
 solarize_image_tensor = _FT.solarize
 solarize_image_pil = _FP.solarize
 
 
-def solarize(inpt: DType, threshold: float) -> DType:
-    if isinstance(inpt, features._Feature):
+def solarize(inpt: features.DType, threshold: float) -> features.DType:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return solarize_image_tensor(inpt, threshold=threshold)
+    elif isinstance(inpt, features._Feature):
         return inpt.solarize(threshold=threshold)
-    elif isinstance(inpt, PIL.Image.Image):
-        return solarize_image_pil(inpt, threshold=threshold)
     else:
-        return solarize_image_tensor(inpt, threshold=threshold)
+        return solarize_image_pil(inpt, threshold=threshold)
 
 
 autocontrast_image_tensor = _FT.autocontrast
 autocontrast_image_pil = _FP.autocontrast
 
 
-def autocontrast(inpt: DType) -> DType:
-    if isinstance(inpt, features._Feature):
+def autocontrast(inpt: features.DType) -> features.DType:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return autocontrast_image_tensor(inpt)
+    elif isinstance(inpt, features._Feature):
         return inpt.autocontrast()
-    elif isinstance(inpt, PIL.Image.Image):
-        return autocontrast_image_pil(inpt)
     else:
-        return autocontrast_image_tensor(inpt)
+        return autocontrast_image_pil(inpt)
 
 
 equalize_image_tensor = _FT.equalize
 equalize_image_pil = _FP.equalize
 
 
-def equalize(inpt: DType) -> DType:
-    if isinstance(inpt, features._Feature):
+def equalize(inpt: features.DType) -> features.DType:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return equalize_image_tensor(inpt)
+    elif isinstance(inpt, features._Feature):
         return inpt.equalize()
-    elif isinstance(inpt, PIL.Image.Image):
-        return equalize_image_pil(inpt)
     else:
-        return equalize_image_tensor(inpt)
+        return equalize_image_pil(inpt)
 
 
 invert_image_tensor = _FT.invert
 invert_image_pil = _FP.invert
 
 
-def invert(inpt: DType) -> DType:
-    if isinstance(inpt, features._Feature):
+def invert(inpt: features.DType) -> features.DType:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return invert_image_tensor(inpt)
+    elif isinstance(inpt, features._Feature):
         return inpt.invert()
-    elif isinstance(inpt, PIL.Image.Image):
-        return invert_image_pil(inpt)
     else:
-        return invert_image_tensor(inpt)
+        return invert_image_pil(inpt)
 import warnings
-from typing import Any, List, Union
+from typing import Any, List
 
 import PIL.Image
 import torch
@@ -8,6 +8,12 @@ from torchvision.prototype import features
 from torchvision.transforms import functional as _F
 
 
+# Due to torch.jit.script limitation we keep LegacyImageType as torch.Tensor
+# instead of Union[torch.Tensor, PIL.Image.Image]
+LegacyImageType = torch.Tensor
+
+
+@torch.jit.unused
 def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
     call = ", num_output_channels=3" if num_output_channels == 3 else ""
     replacement = "convert_color_space(..., color_space=features.ColorSpace.GRAY)"
@@ -21,10 +27,12 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
     return _F.to_grayscale(inpt, num_output_channels=num_output_channels)
 
 
-def rgb_to_grayscale(
-    inpt: Union[PIL.Image.Image, torch.Tensor], num_output_channels: int = 1
-) -> Union[PIL.Image.Image, torch.Tensor]:
-    old_color_space = features.Image.guess_color_space(inpt) if features.is_simple_tensor(inpt) else None
+def rgb_to_grayscale(inpt: LegacyImageType, num_output_channels: int = 1) -> LegacyImageType:
+    old_color_space = (
+        features._image._from_tensor_shape(inpt.shape)  # type: ignore[arg-type]
+        if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image))
+        else None
+    )
 
     call = ", num_output_channels=3" if num_output_channels == 3 else ""
     replacement = (
@@ -44,6 +52,7 @@ def rgb_to_grayscale(
     return _F.rgb_to_grayscale(inpt, num_output_channels=num_output_channels)
 
 
+@torch.jit.unused
 def to_tensor(inpt: Any) -> torch.Tensor:
     warnings.warn(
         "The function `to_tensor(...)` is deprecated and will be removed in a future release. "
@@ -52,7 +61,7 @@ def to_tensor(inpt: Any) -> torch.Tensor:
     return _F.to_tensor(inpt)
 
 
-def get_image_size(inpt: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> List[int]:
+def get_image_size(inpt: features.ImageType) -> List[int]:
     warnings.warn(
         "The function `get_image_size(...)` is deprecated and will be removed in a future release. "
         "Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`."
...
@@ -20,10 +20,6 @@ from torchvision.transforms.functional_tensor import _parse_pad_padding
 from ._meta import convert_format_bounding_box, get_dimensions_image_pil, get_dimensions_image_tensor
 
-# shortcut type
-DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
-
-
 horizontal_flip_image_tensor = _FT.hflip
 horizontal_flip_image_pil = _FP.hflip
@@ -48,13 +44,13 @@ def horizontal_flip_bounding_box(
     ).view(shape)
 
 
-def horizontal_flip(inpt: DType) -> DType:
-    if isinstance(inpt, features._Feature):
+def horizontal_flip(inpt: features.DType) -> features.DType:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return horizontal_flip_image_tensor(inpt)
+    elif isinstance(inpt, features._Feature):
         return inpt.horizontal_flip()
-    elif isinstance(inpt, PIL.Image.Image):
-        return horizontal_flip_image_pil(inpt)
     else:
-        return horizontal_flip_image_tensor(inpt)
+        return horizontal_flip_image_pil(inpt)
 
 
 vertical_flip_image_tensor = _FT.vflip
@@ -81,13 +77,13 @@ def vertical_flip_bounding_box(
     ).view(shape)
 
 
-def vertical_flip(inpt: DType) -> DType:
-    if isinstance(inpt, features._Feature):
+def vertical_flip(inpt: features.DType) -> features.DType:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return vertical_flip_image_tensor(inpt)
+    elif isinstance(inpt, features._Feature):
         return inpt.vertical_flip()
-    elif isinstance(inpt, PIL.Image.Image):
-        return vertical_flip_image_pil(inpt)
     else:
-        return vertical_flip_image_tensor(inpt)
+        return vertical_flip_image_pil(inpt)
 
 
 # We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
@@ -118,6 +114,7 @@ def resize_image_tensor(
     return image.view(extra_dims + (num_channels, new_height, new_width))
 
 
+@torch.jit.unused
 def resize_image_pil(
     img: PIL.Image.Image,
     size: Union[Sequence[int], int],
@@ -157,22 +154,22 @@ def resize_bounding_box(
 def resize(
-    inpt: DType,
+    inpt: features.DType,
     size: List[int],
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     max_size: Optional[int] = None,
     antialias: Optional[bool] = None,
-) -> DType:
-    if isinstance(inpt, features._Feature):
+) -> features.DType:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        antialias = False if antialias is None else antialias
+        return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
+    elif isinstance(inpt, features._Feature):
         antialias = False if antialias is None else antialias
         return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)
-    elif isinstance(inpt, PIL.Image.Image):
+    else:
         if antialias is not None and not antialias:
             warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
         return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size)
-    else:
-        antialias = False if antialias is None else antialias
-        return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
 def _affine_parse_args(
@@ -256,6 +253,7 @@ def affine_image_tensor(
     return output.view(extra_dims + (num_channels, height, width))
 
 
+@torch.jit.unused
 def affine_image_pil(
     img: PIL.Image.Image,
     angle: float,
@@ -263,7 +261,7 @@ def affine_image_pil(
     scale: float,
     shear: List[float],
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
-    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
+    fill: Optional[Union[int, float, List[float]]] = None,
     center: Optional[List[float]] = None,
 ) -> PIL.Image.Image:
     angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
@@ -422,21 +420,17 @@ def _convert_fill_arg(
 def affine(
-    inpt: DType,
+    inpt: features.DType,
     angle: float,
     translate: List[float],
     scale: float,
     shear: List[float],
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
-    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
+    fill: Optional[Union[int, float, List[float]]] = None,
     center: Optional[List[float]] = None,
-) -> DType:
-    if isinstance(inpt, features._Feature):
-        return inpt.affine(
-            angle, translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center
-        )
-    elif isinstance(inpt, PIL.Image.Image):
-        return affine_image_pil(
+) -> features.DType:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return affine_image_tensor(
             inpt,
             angle,
             translate=translate,
@@ -446,10 +440,12 @@ def affine(
             fill=fill,
             center=center,
         )
+    elif isinstance(inpt, features._Feature):
+        return inpt.affine(
+            angle, translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center
+        )
     else:
-        fill = _convert_fill_arg(fill)
-        return affine_image_tensor(
+        return affine_image_pil(
             inpt,
             angle,
             translate=translate,
@@ -499,12 +495,13 @@ def rotate_image_tensor(
     return img.view(extra_dims + (num_channels, new_height, new_width))
 
 
+@torch.jit.unused
 def rotate_image_pil(
     img: PIL.Image.Image,
     angle: float,
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
     expand: bool = False,
-    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
+    fill: Optional[Union[int, float, List[float]]] = None,
     center: Optional[List[float]] = None,
 ) -> PIL.Image.Image:
     if center is not None and expand:
@@ -567,21 +564,19 @@ def rotate_mask(
 def rotate(
-    inpt: DType,
+    inpt: features.DType,
     angle: float,
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
     expand: bool = False,
-    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
+    fill: Optional[Union[int, float, List[float]]] = None,
     center: Optional[List[float]] = None,
-) -> DType:
-    if isinstance(inpt, features._Feature):
+) -> features.DType:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
+    elif isinstance(inpt, features._Feature):
         return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
-    elif isinstance(inpt, PIL.Image.Image):
-        return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
     else:
-        fill = _convert_fill_arg(fill)
-        return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
+        return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
 
 
 pad_image_pil = _FP.pad
@@ -700,23 +695,18 @@ def pad_bounding_box(
 def pad(
-    inpt: DType,
-    padding: Union[int, Sequence[int]],
-    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
+    inpt: features.DType,
+    padding: Union[int, List[int]],
+    fill: Optional[Union[int, float, List[float]]] = None,
     padding_mode: str = "constant",
-) -> DType:
-    if isinstance(inpt, features._Feature):
+) -> features.DType:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
+    elif isinstance(inpt, features._Feature):
         return inpt.pad(padding, fill=fill, padding_mode=padding_mode)
-    elif isinstance(inpt, PIL.Image.Image):
-        return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
     else:
-        # This cast does Sequence[int] -> List[int] and is required to make mypy happy
-        if not isinstance(padding, int):
-            padding = list(padding)
-
-        fill = _convert_fill_arg(fill)
-        return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
+        return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
 
 
 crop_image_tensor = _FT.crop
@@ -746,13 +736,13 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
     return crop_image_tensor(mask, top, left, height, width)
 
 
-def crop(inpt: DType, top: int, left: int, height: int, width: int) -> DType:
-    if isinstance(inpt, features._Feature):
+def crop(inpt: features.DType, top: int, left: int, height: int, width: int) -> features.DType:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return crop_image_tensor(inpt, top, left, height, width)
+    elif isinstance(inpt, features._Feature):
         return inpt.crop(top, left, height, width)
-    elif isinstance(inpt, PIL.Image.Image):
-        return crop_image_pil(inpt, top, left, height, width)
     else:
-        return crop_image_tensor(inpt, top, left, height, width)
+        return crop_image_pil(inpt, top, left, height, width)
 
 
 def perspective_image_tensor(
@@ -764,6 +754,7 @@ def perspective_image_tensor(
     return _FT.perspective(img, perspective_coeffs, interpolation=interpolation.value, fill=fill)
 
 
+@torch.jit.unused
 def perspective_image_pil(
     img: PIL.Image.Image,
     perspective_coeffs: List[float],
@@ -876,22 +867,20 @@ def perspective_mask(mask: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor:
 def perspective(
-    inpt: DType,
+    inpt: features.DType,
     startpoints: List[List[int]],
     endpoints: List[List[int]],
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
-) -> DType:
+    fill: Optional[Union[int, float, List[float]]] = None,
+) -> features.DType:
     perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints)
 
-    if isinstance(inpt, features._Feature):
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
+    elif isinstance(inpt, features._Feature):
         return inpt.perspective(perspective_coeffs, interpolation=interpolation, fill=fill)
-    elif isinstance(inpt, PIL.Image.Image):
-        return perspective_image_pil(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
     else:
-        fill = _convert_fill_arg(fill)
-        return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
+        return perspective_image_pil(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
 
 
 def elastic_image_tensor(
@@ -903,15 +892,14 @@ def elastic_image_tensor(
     return _FT.elastic_transform(img, displacement, interpolation=interpolation.value, fill=fill)
 
 
+@torch.jit.unused
 def elastic_image_pil(
     img: PIL.Image.Image,
     displacement: torch.Tensor,
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
+    fill: Optional[Union[int, float, List[float]]] = None,
 ) -> PIL.Image.Image:
     t_img = pil_to_tensor(img)
-    fill = _convert_fill_arg(fill)
     output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill)
     return to_pil_image(output, mode=img.mode)
@@ -972,19 +960,17 @@ def elastic_mask(mask: torch.Tensor, displacement: torch.Tensor) -> torch.Tensor:
 def elastic(
-    inpt: DType,
+    inpt: features.DType,
     displacement: torch.Tensor,
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
-) -> DType:
-    if isinstance(inpt, features._Feature):
+    fill: Optional[Union[int, float, List[float]]] = None,
+) -> features.DType:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
+    elif isinstance(inpt, features._Feature):
        return inpt.elastic(displacement, interpolation=interpolation, fill=fill)
-    elif isinstance(inpt, PIL.Image.Image):
-        return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill)
     else:
-        fill = _convert_fill_arg(fill)
-        return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
+        return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill)
 
 
 elastic_transform = elastic
...@@ -1032,6 +1018,7 @@ def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch ...@@ -1032,6 +1018,7 @@ def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch
return crop_image_tensor(img, crop_top, crop_left, crop_height, crop_width) return crop_image_tensor(img, crop_top, crop_left, crop_height, crop_width)
@torch.jit.unused
def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image: def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
crop_height, crop_width = _center_crop_parse_output_size(output_size) crop_height, crop_width = _center_crop_parse_output_size(output_size)
_, image_height, image_width = get_dimensions_image_pil(img) _, image_height, image_width = get_dimensions_image_pil(img)
...@@ -1074,13 +1061,13 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor ...@@ -1074,13 +1061,13 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor
return output return output
def center_crop(inpt: DType, output_size: List[int]) -> DType: def center_crop(inpt: features.DType, output_size: List[int]) -> features.DType:
if isinstance(inpt, features._Feature): if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return center_crop_image_tensor(inpt, output_size)
elif isinstance(inpt, features._Feature):
return inpt.center_crop(output_size) return inpt.center_crop(output_size)
elif isinstance(inpt, PIL.Image.Image):
return center_crop_image_pil(inpt, output_size)
else: else:
return center_crop_image_tensor(inpt, output_size) return center_crop_image_pil(inpt, output_size)
def resized_crop_image_tensor( def resized_crop_image_tensor(
...@@ -1097,6 +1084,7 @@ def resized_crop_image_tensor( ...@@ -1097,6 +1084,7 @@ def resized_crop_image_tensor(
return resize_image_tensor(img, size, interpolation=interpolation, antialias=antialias) return resize_image_tensor(img, size, interpolation=interpolation, antialias=antialias)
@torch.jit.unused
def resized_crop_image_pil( def resized_crop_image_pil(
img: PIL.Image.Image, img: PIL.Image.Image,
top: int, top: int,
@@ -1136,7 +1124,7 @@ def resized_crop_mask(
 def resized_crop(
-    inpt: DType,
+    inpt: features.DType,
     top: int,
     left: int,
     height: int,
@@ -1144,17 +1132,17 @@ def resized_crop(
     size: List[int],
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     antialias: Optional[bool] = None,
-) -> DType:
-    if isinstance(inpt, features._Feature):
-        antialias = False if antialias is None else antialias
-        return inpt.resized_crop(top, left, height, width, antialias=antialias, size=size, interpolation=interpolation)
-    elif isinstance(inpt, PIL.Image.Image):
-        return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation)
-    else:
+) -> features.DType:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         antialias = False if antialias is None else antialias
         return resized_crop_image_tensor(
             inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
         )
+    elif isinstance(inpt, features._Feature):
+        antialias = False if antialias is None else antialias
+        return inpt.resized_crop(top, left, height, width, antialias=antialias, size=size, interpolation=interpolation)
+    else:
+        return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation)
 def _parse_five_crop_size(size: List[int]) -> List[int]:
@@ -1188,6 +1176,7 @@ def five_crop_image_tensor(
     return tl, tr, bl, br, center


+@torch.jit.unused
 def five_crop_image_pil(
     img: PIL.Image.Image, size: List[int]
 ) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
@@ -1207,11 +1196,13 @@ def five_crop_image_pil(
     return tl, tr, bl, br, center


-def five_crop(inpt: DType, size: List[int]) -> Tuple[DType, DType, DType, DType, DType]:
-    # TODO: consider breaking BC here to return List[DType] to align this op with `ten_crop`
+def five_crop(
+    inpt: features.ImageType, size: List[int]
+) -> Tuple[features.ImageType, features.ImageType, features.ImageType, features.ImageType, features.ImageType]:
+    # TODO: consider breaking BC here to return List[features.ImageType] to align this op with `ten_crop`
     if isinstance(inpt, torch.Tensor):
         output = five_crop_image_tensor(inpt, size)
-        if isinstance(inpt, features.Image):
+        if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
             output = tuple(features.Image.new_like(inpt, item) for item in output)  # type: ignore[assignment]
         return output
     else:  # isinstance(inpt, PIL.Image.Image):
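Note: the `not torch.jit.is_scripting()` guard keeps the `features.Image.new_like` re-wrapping out of the scripted graph, where the subclass cannot be constructed. Eager callers still get `Image`s back; scripted callers get plain tensors. Hypothetical usage:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

img = features.Image(torch.rand(3, 32, 32))
tl, tr, bl, br, center = F.five_crop(img, size=[16, 16])
assert isinstance(tl, features.Image)  # eager: wrapped back into Image

scripted = torch.jit.script(F.five_crop)
tl_s = scripted(torch.rand(3, 32, 32), size=[16, 16])[0]
assert type(tl_s) is torch.Tensor  # scripted: plain tensor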
@@ -1231,6 +1222,7 @@ def ten_crop_image_tensor(img: torch.Tensor, size: List[int], vertical_flip: bool
     return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]


+@torch.jit.unused
 def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: bool = False) -> List[PIL.Image.Image]:
     tl, tr, bl, br, center = five_crop_image_pil(img, size)
@@ -1244,10 +1236,10 @@ def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: bool
     return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]


-def ten_crop(inpt: DType, size: List[int], vertical_flip: bool = False) -> List[DType]:
+def ten_crop(inpt: features.ImageType, size: List[int], vertical_flip: bool = False) -> List[features.ImageType]:
     if isinstance(inpt, torch.Tensor):
         output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
-        if isinstance(inpt, features.Image):
+        if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
             output = [features.Image.new_like(inpt, item) for item in output]
         return output
     else:  # isinstance(inpt, PIL.Image.Image):
...
-from typing import Any, List, Optional, Tuple, Union
+from typing import cast, List, Optional, Tuple

 import PIL.Image
 import torch
 from torchvision.prototype import features
-from torchvision.prototype.features import BoundingBoxFormat, ColorSpace, Image
+from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

 get_dimensions_image_tensor = _FT.get_dimensions
 get_dimensions_image_pil = _FP.get_dimensions


-def get_chw(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
-    if isinstance(image, features.Image):
+# TODO: Should this be prefixed with `_` similar to other methods that don't get exposed by init?
+def get_chw(image: features.ImageType) -> Tuple[int, int, int]:
+    if isinstance(image, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(image, features.Image)):
+        channels, height, width = get_dimensions_image_tensor(image)
+    elif isinstance(image, features.Image):
         channels = image.num_channels
         height, width = image.image_size
-    elif features.is_simple_tensor(image):
-        channels, height, width = get_dimensions_image_tensor(image)
-    elif isinstance(image, PIL.Image.Image):
+    else:  # isinstance(image, PIL.Image.Image)
         channels, height, width = get_dimensions_image_pil(image)
-    else:
-        raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}")
     return channels, height, width
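Note: same re-ordering as the geometry dispatchers, and the old catch-all `TypeError` disappears because the final `else` can now assume PIL. A quick scripted check of one of the thin wrappers below; whether `get_dimensions` is re-exported under `F` here is an assumption:

import torch
from torchvision.prototype.transforms import functional as F

assert torch.jit.script(F.get_dimensions)(torch.rand(3, 8, 10)) == [3, 8, 10]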
@@ -30,11 +29,11 @@ def get_chw(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
 # detailed above.
-def get_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> List[int]:
+def get_dimensions(image: features.ImageType) -> List[int]:
     return list(get_chw(image))


-def get_num_channels(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> int:
+def get_num_channels(image: features.ImageType) -> int:
     num_channels, *_ = get_chw(image)
     return num_channels
@@ -44,7 +43,7 @@ def get_num_channels(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> int:
 get_image_num_channels = get_num_channels


-def get_spatial_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> List[int]:
+def get_spatial_size(image: features.ImageType) -> List[int]:
     _, *size = get_chw(image)
     return size
@@ -192,6 +191,7 @@ _COLOR_SPACE_TO_PIL_MODE = {
 }


+@torch.jit.unused
 def convert_color_space_image_pil(
     image: PIL.Image.Image, color_space: ColorSpace, copy: bool = True
 ) -> PIL.Image.Image:
@@ -208,17 +208,12 @@ def convert_color_space_image_pil(
 def convert_color_space(
-    inpt: Union[PIL.Image.Image, torch.Tensor, features._Feature],
-    *,
+    inpt: features.ImageType,
     color_space: ColorSpace,
     old_color_space: Optional[ColorSpace] = None,
     copy: bool = True,
-) -> Any:
-    if isinstance(inpt, Image):
-        return inpt.to_color_space(color_space, copy=copy)
-    elif isinstance(inpt, PIL.Image.Image):
-        return convert_color_space_image_pil(inpt, color_space, copy=copy)
-    else:
+) -> features.ImageType:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image)):
         if old_color_space is None:
             raise RuntimeError(
                 "In order to convert the color space of simple tensor images, "
@@ -227,3 +222,7 @@ def convert_color_space(
         return convert_color_space_image_tensor(
             inpt, old_color_space=old_color_space, new_color_space=color_space, copy=copy
         )
+    elif isinstance(inpt, features.Image):
+        return inpt.to_color_space(color_space, copy=copy)
+    else:
+        return cast(features.ImageType, convert_color_space_image_pil(inpt, color_space, copy=copy))
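Note: a simple tensor carries no color-space metadata, so the tensor branch insists on `old_color_space`, while `features.Image` inputs infer it from the feature itself. Hypothetical usage:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

rgb = torch.rand(3, 4, 4)
gray = F.convert_color_space(
    rgb, color_space=features.ColorSpace.GRAY, old_color_space=features.ColorSpace.RGB
)
# Omitting old_color_space for a plain tensor raises the RuntimeError above.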
-from typing import List, Optional, Union
+from typing import List, Optional

 import PIL.Image
 import torch
@@ -7,16 +7,15 @@ from torchvision.transforms import functional_tensor as _FT
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image

-# shortcut type
-DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
+# Due to torch.jit.script limitation we keep TensorImageType as torch.Tensor
+# instead of Union[torch.Tensor, features.Image]
+TensorImageType = torch.Tensor
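Note: the alias has to stay a plain `torch.Tensor`: TorchScript erases tensor subclasses, so per the comment above a `Union[torch.Tensor, features.Image]` annotation is not currently scriptable, while the plain annotation still accepts `features.Image` in eager mode because it subclasses `Tensor`. A sketch under that assumption:

import torch
from torchvision.prototype import features

TensorImageType = torch.Tensor  # not Union[torch.Tensor, features.Image]

def scale(inpt: TensorImageType, factor: float) -> torch.Tensor:
    return inpt * factor

torch.jit.script(scale)                          # fine: annotation is plain Tensor
scale(features.Image(torch.rand(3, 2, 2)), 2.0)  # eager still takes Image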
 normalize_image_tensor = _FT.normalize


-def normalize(
-    inpt: Union[torch.Tensor, features.Image], mean: List[float], std: List[float], inplace: bool = False
-) -> torch.Tensor:
+def normalize(inpt: TensorImageType, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
     if not isinstance(inpt, torch.Tensor):
         raise TypeError(f"img should be Tensor Image. Got {type(inpt)}")
     else:
@@ -54,6 +53,7 @@ def gaussian_blur_image_tensor(
     return _FT.gaussian_blur(img, kernel_size, sigma)


+@torch.jit.unused
 def gaussian_blur_image_pil(
     img: PIL.Image.Image, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> PIL.Image.Image:
@@ -62,10 +62,10 @@ def gaussian_blur_image_pil(
     return to_pil_image(output, mode=img.mode)
-def gaussian_blur(inpt: DType, kernel_size: List[int], sigma: Optional[List[float]] = None) -> DType:
-    if isinstance(inpt, features._Feature):
+def gaussian_blur(inpt: features.DType, kernel_size: List[int], sigma: Optional[List[float]] = None) -> features.DType:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
+    elif isinstance(inpt, features._Feature):
         return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma)
-    elif isinstance(inpt, PIL.Image.Image):
-        return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma)
     else:
-        return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
+        return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma)
@@ -10,6 +10,7 @@ from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer
 from torchvision.transforms import functional as _F


+@torch.jit.unused
 def decode_image_with_pil(encoded_image: torch.Tensor) -> features.Image:
     image = torch.as_tensor(np.array(PIL.Image.open(ReadOnlyTensorBuffer(encoded_image)), copy=True))
     if image.ndim == 2:
@@ -17,11 +18,13 @@ def decode_image_with_pil(encoded_image: torch.Tensor) -> features.Image:
     return features.Image(image.permute(2, 0, 1))


+@torch.jit.unused
 def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
     with unittest.mock.patch("torchvision.io.video.os.path.exists", return_value=True):
         return read_video(ReadOnlyTensorBuffer(encoded_video))  # type: ignore[arg-type]


+@torch.jit.unused
 def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> features.Image:
     if isinstance(image, np.ndarray):
         output = torch.from_numpy(image)
...