Unverified Commit 841b9a19 authored by Vasilis Vryniotis, committed by GitHub

Make prototype `F` JIT-scriptable (#6584)



* Improve existing low-level kernel test.

* Add new midlevel jit-scriptability test (failing).

* Remove duplicate aliases from kernel tests.

* Fixing colour kernels.

* Fixing deprecated kernels.

* fix mypy

* Silence mypy instead of fixing to avoid performance penalty

* Fixing augment kernels.

* Fixing augment meta.

* Remove is_tracing calls.

* Add fake ImageType and DType

* Fixing type conversion kernels.

* Fixing misc kernels.

* partial fix geometry

* Remove mutable default from `_pad_with_vector_fill()` + all other unnecessary defaults.

* Fix geometry ops

* Fixing tests

* Removed xfail for jit tests on midlevel ops
Co-authored-by: vfdev-5 <vfdev.5@gmail.com>
parent 3a1f05ed
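
The change repeated throughout this diff is a reordering of every mid-level dispatcher: the plain-tensor branch now comes first, guarded by `torch.jit.is_scripting()`, and under scripting the input types are aliased to plain `torch.Tensor` (see the `DType`/`ImageType` aliases below). A minimal usage sketch of the result, assuming the prototype functional module is importable as `torchvision.prototype.transforms.functional`:

    import torch
    from torchvision.prototype.transforms import functional as F

    # After this change the mid-level dispatchers script directly; before it,
    # the `isinstance(inpt, features._Feature)`-first ordering made scripting fail.
    scripted_flip = torch.jit.script(F.horizontal_flip)
    out = scripted_flip(torch.rand(3, 32, 32))  # takes the tensor-kernel path
    print(out.shape)  # torch.Size([3, 32, 32])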
@@ -113,14 +113,21 @@ DISPATCHER_INFOS = [
features.Mask: F.pad_mask,
},
),
DispatcherInfo(
F.perspective,
kernels={
features.Image: F.perspective_image_tensor,
features.BoundingBox: F.perspective_bounding_box,
features.Mask: F.perspective_mask,
},
),
# FIXME:
# RuntimeError: perspective() is missing value for argument 'startpoints'.
# Declaration: perspective(Tensor inpt, int[][] startpoints, int[][] endpoints,
# Enum<__torch__.torchvision.transforms.functional.InterpolationMode> interpolation=Enum<InterpolationMode.BILINEAR>,
# Union(float[], float, int, NoneType) fill=None) -> Tensor
#
# This is probably due to the fact that F.perspective does not have the same signature as F.perspective_image_tensor
# DispatcherInfo(
# F.perspective,
# kernels={
# features.Image: F.perspective_image_tensor,
# features.BoundingBox: F.perspective_bounding_box,
# features.Mask: F.perspective_mask,
# },
# ),
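# NOTE: the signature mismatch, as visible elsewhere in this diff:
#     perspective(inpt, startpoints: List[List[int]], endpoints: List[List[int]], ...)
#     perspective_image_tensor(img, perspective_coeffs: List[float], ...)
# The DispatcherInfo harness presumably replays kernel-style sample args through
# the scripted dispatcher, so `startpoints` is never supplied and scripting
# surfaces the RuntimeError quoted above.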
DispatcherInfo(
F.center_crop,
kernels={
......
@@ -376,6 +376,9 @@ class TestPad:
inpt = mocker.MagicMock(spec=features.Image)
_ = transform(inpt)
fill = transforms.functional._geometry._convert_fill_arg(fill)
if isinstance(padding, tuple):
padding = list(padding)
fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
@pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}])
@@ -389,14 +392,17 @@ class TestPad:
_ = transform(inpt)
if isinstance(fill, int):
fill = transforms.functional._geometry._convert_fill_arg(fill)
calls = [
mocker.call(image, padding=1, fill=fill, padding_mode="constant"),
mocker.call(mask, padding=1, fill=fill, padding_mode="constant"),
]
else:
fill_img = transforms.functional._geometry._convert_fill_arg(fill[type(image)])
fill_mask = transforms.functional._geometry._convert_fill_arg(fill[type(mask)])
calls = [
mocker.call(image, padding=1, fill=fill[type(image)], padding_mode="constant"),
mocker.call(mask, padding=1, fill=fill[type(mask)], padding_mode="constant"),
mocker.call(image, padding=1, fill=fill_img, padding_mode="constant"),
mocker.call(mask, padding=1, fill=fill_mask, padding_mode="constant"),
]
fn.assert_has_calls(calls)
@@ -447,6 +453,7 @@ class TestRandomZoomOut:
torch.rand(1) # random apply changes random state
params = transform._get_params(inpt)
fill = transforms.functional._geometry._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill)
@pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}])
@@ -465,14 +472,17 @@ class TestRandomZoomOut:
params = transform._get_params(inpt)
if isinstance(fill, int):
fill = transforms.functional._geometry._convert_fill_arg(fill)
calls = [
mocker.call(image, **params, fill=fill),
mocker.call(mask, **params, fill=fill),
]
else:
fill_img = transforms.functional._geometry._convert_fill_arg(fill[type(image)])
fill_mask = transforms.functional._geometry._convert_fill_arg(fill[type(mask)])
calls = [
mocker.call(image, **params, fill=fill[type(image)]),
mocker.call(mask, **params, fill=fill[type(mask)]),
mocker.call(image, **params, fill=fill_img),
mocker.call(mask, **params, fill=fill_mask),
]
fn.assert_has_calls(calls)
@@ -533,6 +543,7 @@ class TestRandomRotation:
torch.manual_seed(12)
params = transform._get_params(inpt)
fill = transforms.functional._geometry._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, interpolation=interpolation, expand=expand, fill=fill, center=center)
@pytest.mark.parametrize("angle", [34, -87])
@@ -670,6 +681,7 @@ class TestRandomAffine:
torch.manual_seed(12)
params = transform._get_params(inpt)
fill = transforms.functional._geometry._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, interpolation=interpolation, fill=fill, center=center)
@@ -917,6 +929,7 @@ class TestRandomPerspective:
torch.rand(1) # random apply changes random state
params = transform._get_params(inpt)
fill = transforms.functional._geometry._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
@@ -986,6 +999,7 @@ class TestElasticTransform:
transform._get_params = mocker.MagicMock()
_ = transform(inpt)
params = transform._get_params(inpt)
fill = transforms.functional._geometry._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
@@ -1609,6 +1623,7 @@ class TestFixedSizeCrop:
if not needs_crop:
assert args[0] is inpt_sentinel
assert args[1] is padding_sentinel
fill_sentinel = transforms.functional._geometry._convert_fill_arg(fill_sentinel)
assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel)
else:
mock_pad.assert_not_called()
......
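
The tests above now pre-convert `fill` through `_convert_fill_arg` before asserting on the mocked kernel call, mirroring the transforms further down, which apply the same conversion before dispatching. The helper's body is not part of this diff; a hypothetical minimal version, for orientation only:

    from typing import List, Optional, Union

    def _convert_fill_arg(
        fill: Optional[Union[int, float, List[int], List[float]]]
    ) -> Optional[Union[float, List[float]]]:
        # Sketch only -- the real helper lives in
        # torchvision.prototype.transforms.functional._geometry. The tests rely
        # solely on the same conversion being applied on both sides of the
        # assertion, not on this particular body.
        if fill is None:
            return None
        if isinstance(fill, (int, float)):
            return float(fill)
        return [float(v) for v in fill]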
@@ -9,7 +9,6 @@ from torchvision.prototype import features
class TestCommon:
@pytest.mark.xfail(reason="dispatchers are currently not scriptable")
@pytest.mark.parametrize(
("info", "args_kwargs"),
[
......
@@ -407,27 +407,74 @@ def erase_image_tensor():
yield ArgsKwargs(image, i=1, j=2, h=6, w=7, v=torch.rand(c, 6, 7))
_KERNEL_TYPES = {"_image_tensor", "_image_pil", "_mask", "_bounding_box", "_label"}
def _distinct_callables(callable_names):
# Deduplicate callables (aliases point to the same object) without losing the new API names
remove = set()
distinct = set()
for name in callable_names:
item = F.__dict__[name]
if item not in distinct:
distinct.add(item)
else:
remove.add(name)
callable_names -= remove
# build (name, callable) tuples sorted by name
return sorted([(name, F.__dict__[name]) for name in callable_names], key=lambda t: t[0])
def _get_distinct_kernels():
kernel_names = {
name
for name, f in F.__dict__.items()
if callable(f) and not name.startswith("_") and any(name.endswith(k) for k in _KERNEL_TYPES)
}
return _distinct_callables(kernel_names)
def _get_distinct_midlevels():
midlevel_names = {
name
for name, f in F.__dict__.items()
if callable(f) and not name.startswith("_") and not any(name.endswith(k) for k in _KERNEL_TYPES)
}
return _distinct_callables(midlevel_names)
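# NOTE: quick illustration of the deduplication above. `hflip` is a pure alias
# of `horizontal_flip` (see the _geometry diff below), so both names resolve to
# the same object and only one survives; which one is kept depends on set
# iteration order, since the first name seen wins:
#
#   _distinct_callables({"hflip", "horizontal_flip", "vertical_flip"})
#   # -> [("horizontal_flip", <function ...>), ("vertical_flip", <function ...>)]
#   #    (or ("hflip", ...) instead, depending on iteration order)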
@pytest.mark.parametrize(
"kernel",
[
pytest.param(kernel, id=name)
for name, kernel in F.__dict__.items()
if not name.startswith("_")
and callable(kernel)
and any(feature_type in name for feature_type in {"image", "mask", "bounding_box", "label"})
and "pil" not in name
and name
for name, kernel in _get_distinct_kernels()
if not name.endswith("_image_pil") and name not in {"to_image_tensor"}
],
)
def test_scriptable_kernel(kernel):
jit.script(kernel) # TODO: pass data through it
@pytest.mark.parametrize(
"midlevel",
[
pytest.param(midlevel, id=name)
for name, midlevel in _get_distinct_midlevels()
if name
not in {
"to_image_tensor",
"get_num_channels",
"get_spatial_size",
"get_image_num_channels",
"get_image_size",
"InterpolationMode",
"decode_image_with_pil",
"decode_video_with_av",
"pil_to_tensor",
"to_grayscale",
"to_pil_image",
"to_tensor",
}
],
)
def test_scriptable(kernel):
jit.script(kernel)
def test_scriptable_midlevel(midlevel):
jit.script(midlevel) # TODO: pass data through it
# The test below checks each mid-level op against the low-level kernels it calls
@@ -439,8 +486,8 @@ def test_scriptable(kernel):
[
pytest.param(func, id=name)
for name, func in F.__dict__.items()
if not name.startswith("_")
and callable(func)
if not name.startswith("_") and callable(func)
# TODO: remove aliases
and all(feature_type not in name for feature_type in {"image", "mask", "bounding_box", "label", "pil"})
and name
not in {
......
from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._encoded import EncodedData, EncodedImage, EncodedVideo
from ._feature import _Feature, is_simple_tensor
from ._image import ColorSpace, Image
from ._feature import _Feature, DType, is_simple_tensor
from ._image import ColorSpace, Image, ImageType
from ._label import Label, OneHotLabel
from ._mask import Mask
@@ -10,6 +10,11 @@ from torchvision.transforms import InterpolationMode
F = TypeVar("F", bound="_Feature")
# Due to torch.jit.script limitation we keep DType as torch.Tensor
# instead of Union[torch.Tensor, PIL.Image.Image, features._Feature]
DType = torch.Tensor
def is_simple_tensor(inpt: Any) -> bool:
return isinstance(inpt, torch.Tensor) and not isinstance(inpt, _Feature)
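# NOTE: the intended contract, illustrated (assuming a features.Image can be
# constructed directly from a CHW tensor, per the _image.py diff below):
#
#   t = torch.rand(3, 4, 4)
#   is_simple_tensor(t)                  # True: a plain tensor
#   is_simple_tensor(features.Image(t))  # False: a _Feature subclass
#   is_simple_tensor("not a tensor")     # False: not a tensor at all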
......
@@ -12,6 +12,11 @@ from ._bounding_box import BoundingBox
from ._feature import _Feature
# Due to torch.jit.script limitation we keep ImageType as torch.Tensor
# instead of Union[torch.Tensor, PIL.Image.Image, features.Image]
ImageType = torch.Tensor
class ColorSpace(StrEnum):
OTHER = StrEnum.auto()
GRAY = StrEnum.auto()
@@ -32,6 +37,31 @@ class ColorSpace(StrEnum):
else:
return cls.OTHER
@staticmethod
def from_tensor_shape(shape: List[int]) -> ColorSpace:
return _from_tensor_shape(shape)
def _from_tensor_shape(shape: List[int]) -> ColorSpace:
# Needed as a standalone function so that JIT can script it
ndim = len(shape)
if ndim < 2:
return ColorSpace.OTHER
elif ndim == 2:
return ColorSpace.GRAY
num_channels = shape[-3]
if num_channels == 1:
return ColorSpace.GRAY
elif num_channels == 2:
return ColorSpace.GRAY_ALPHA
elif num_channels == 3:
return ColorSpace.RGB
elif num_channels == 4:
return ColorSpace.RGB_ALPHA
else:
return ColorSpace.OTHER
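# NOTE: worked examples of the mapping above, verifiable against the body:
#
#   _from_tensor_shape([16, 16])          -> ColorSpace.GRAY        # ndim == 2
#   _from_tensor_shape([2, 16, 16])       -> ColorSpace.GRAY_ALPHA
#   _from_tensor_shape([3, 224, 224])     -> ColorSpace.RGB
#   _from_tensor_shape([4, 224, 224])     -> ColorSpace.RGB_ALPHA
#   _from_tensor_shape([8, 3, 224, 224])  -> ColorSpace.RGB         # dim -3 decides
#   _from_tensor_shape([5])               -> ColorSpace.OTHER       # ndim < 2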
class Image(_Feature):
color_space: ColorSpace
@@ -53,7 +83,7 @@ class Image(_Feature):
image = super().__new__(cls, data, requires_grad=requires_grad)
if color_space is None:
color_space = cls.guess_color_space(image)
color_space = ColorSpace.from_tensor_shape(image.shape) # type: ignore[arg-type]
if color_space == ColorSpace.OTHER:
warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
elif isinstance(color_space, str):
@@ -83,25 +113,6 @@ class Image(_Feature):
def num_channels(self) -> int:
return self.shape[-3]
@staticmethod
def guess_color_space(data: torch.Tensor) -> ColorSpace:
if data.ndim < 2:
return ColorSpace.OTHER
elif data.ndim == 2:
return ColorSpace.GRAY
num_channels = data.shape[-3]
if num_channels == 1:
return ColorSpace.GRAY
elif num_channels == 2:
return ColorSpace.GRAY_ALPHA
elif num_channels == 3:
return ColorSpace.RGB
elif num_channels == 4:
return ColorSpace.RGB_ALPHA
else:
return ColorSpace.OTHER
def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) -> Image:
if isinstance(color_space, str):
color_space = ColorSpace.from_str(color_space.upper())
......
@@ -72,11 +72,10 @@ class _AutoAugmentBase(Transform):
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# So, we have to put fill as None if fill == 0
fill_: Optional[Union[int, float, Sequence[int], Sequence[float]]]
# This is for BC with the stable API, which has fill=None by default
fill_ = F._geometry._convert_fill_arg(fill)
if isinstance(fill, int) and fill == 0:
fill_ = None
else:
fill_ = fill
if transform_id == "Identity":
return image
......
@@ -252,7 +252,14 @@ class Pad(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
padding = self.padding
if not isinstance(padding, int):
padding = list(padding)
fill = F._geometry._convert_fill_arg(fill)
return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
class RandomZoomOut(_RandomApplyTransform):
@@ -290,6 +297,7 @@ class RandomZoomOut(_RandomApplyTransform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.pad(inpt, **params, fill=fill)
@@ -320,6 +328,7 @@ class RandomRotation(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.rotate(
inpt,
**params,
@@ -401,6 +410,7 @@ class RandomAffine(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.affine(
inpt,
**params,
@@ -480,8 +490,15 @@ class RandomCrop(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# TODO: (PERF) check for speed optimization if we avoid repeated pad calls
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
if self.padding is not None:
inpt = F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
padding = self.padding
if not isinstance(padding, int):
padding = list(padding)
inpt = F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
if self.pad_if_needed:
input_width, input_height = params["input_width"], params["input_height"]
@@ -543,6 +560,7 @@ class RandomPerspective(_RandomApplyTransform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.perspective(
inpt,
**params,
@@ -610,6 +628,7 @@ class ElasticTransform(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.elastic(
inpt,
**params,
@@ -868,6 +887,7 @@ class FixedSizeCrop(Transform):
if params["needs_pad"]:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
return inpt
......
from typing import Any, Dict, Optional, Union
from typing import Any, cast, Dict, Optional, Union
import numpy as np
import PIL.Image
@@ -13,7 +13,7 @@ class DecodeImage(Transform):
_transformed_types = (features.EncodedImage,)
def _transform(self, inpt: torch.Tensor, params: Dict[str, Any]) -> features.Image:
return F.decode_image_with_pil(inpt)
return cast(features.Image, F.decode_image_with_pil(inpt))
class LabelToOneHot(Transform):
@@ -50,7 +50,7 @@ class ToImageTensor(Transform):
def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> features.Image:
return F.to_image_tensor(inpt)
return cast(features.Image, F.to_image_tensor(inpt))
class ToImagePIL(Transform):
......
@@ -74,7 +74,7 @@ from ._geometry import (
five_crop,
five_crop_image_pil,
five_crop_image_tensor,
hflip,
hflip, # TODO: Consider moving all pure alias definitions at the bottom of the file
horizontal_flip,
horizontal_flip_bounding_box,
horizontal_flip_image_pil,
......
from typing import Union
import PIL.Image
import torch
@@ -11,6 +9,7 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image
erase_image_tensor = _FT.erase
@torch.jit.unused
def erase_image_pil(
img: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> PIL.Image.Image:
@@ -20,17 +19,17 @@ def erase_image_pil(
def erase(
inpt: Union[torch.Tensor, PIL.Image.Image, features.Image],
inpt: features.ImageType,
i: int,
j: int,
h: int,
w: int,
v: torch.Tensor,
inplace: bool = False,
) -> Union[torch.Tensor, PIL.Image.Image, features.Image]:
) -> features.ImageType:
if isinstance(inpt, torch.Tensor):
output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
if isinstance(inpt, features.Image):
if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
output = features.Image.new_like(inpt, output)
return output
else: # isinstance(inpt, PIL.Image.Image):
......
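
Two devices recur from here on: `@torch.jit.unused` marks the PIL kernels so the TorchScript compiler skips their bodies, and eager-only feature re-wrapping (e.g. `features.Image.new_like` in `erase` above) is gated behind `not torch.jit.is_scripting()`, since feature construction is not scriptable. A reduced, self-contained sketch of the first device, with hypothetical names:

    import torch

    @torch.jit.unused
    def pil_kernel(img) -> torch.Tensor:
        # Stand-in for the *_image_pil kernels: under scripting this body is
        # replaced with a stub that raises if called, so PIL-typed helpers no
        # longer block compilation of the functions that reference them.
        raise NotImplementedError("eager-only path")

    def dispatch(inpt: torch.Tensor, use_pil: bool = False) -> torch.Tensor:
        if use_pil:
            return pil_kernel(inpt)  # compiled as a raising stub
        return inpt.flip(-1)

    scripted = torch.jit.script(dispatch)  # compiles despite the PIL-only helper
    print(scripted(torch.rand(3, 4, 4)).shape)  # torch.Size([3, 4, 4])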
from typing import Union
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
# shortcut type
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
adjust_brightness_image_tensor = _FT.adjust_brightness
adjust_brightness_image_pil = _FP.adjust_brightness
def adjust_brightness(inpt: DType, brightness_factor: float) -> DType:
if isinstance(inpt, features._Feature):
def adjust_brightness(inpt: features.DType, brightness_factor: float) -> features.DType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
elif isinstance(inpt, features._Feature):
return inpt.adjust_brightness(brightness_factor=brightness_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
else:
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
adjust_saturation_image_tensor = _FT.adjust_saturation
adjust_saturation_image_pil = _FP.adjust_saturation
def adjust_saturation(inpt: DType, saturation_factor: float) -> DType:
if isinstance(inpt, features._Feature):
def adjust_saturation(inpt: features.DType, saturation_factor: float) -> features.DType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
elif isinstance(inpt, features._Feature):
return inpt.adjust_saturation(saturation_factor=saturation_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
else:
return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
adjust_contrast_image_tensor = _FT.adjust_contrast
adjust_contrast_image_pil = _FP.adjust_contrast
def adjust_contrast(inpt: DType, contrast_factor: float) -> DType:
if isinstance(inpt, features._Feature):
def adjust_contrast(inpt: features.DType, contrast_factor: float) -> features.DType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
elif isinstance(inpt, features._Feature):
return inpt.adjust_contrast(contrast_factor=contrast_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
else:
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
adjust_sharpness_image_tensor = _FT.adjust_sharpness
adjust_sharpness_image_pil = _FP.adjust_sharpness
def adjust_sharpness(inpt: DType, sharpness_factor: float) -> DType:
if isinstance(inpt, features._Feature):
def adjust_sharpness(inpt: features.DType, sharpness_factor: float) -> features.DType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
elif isinstance(inpt, features._Feature):
return inpt.adjust_sharpness(sharpness_factor=sharpness_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
else:
return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
adjust_hue_image_tensor = _FT.adjust_hue
adjust_hue_image_pil = _FP.adjust_hue
def adjust_hue(inpt: DType, hue_factor: float) -> DType:
if isinstance(inpt, features._Feature):
def adjust_hue(inpt: features.DType, hue_factor: float) -> features.DType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
elif isinstance(inpt, features._Feature):
return inpt.adjust_hue(hue_factor=hue_factor)
elif isinstance(inpt, PIL.Image.Image):
return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
else:
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
adjust_gamma_image_tensor = _FT.adjust_gamma
adjust_gamma_image_pil = _FP.adjust_gamma
def adjust_gamma(inpt: DType, gamma: float, gain: float = 1) -> DType:
if isinstance(inpt, features._Feature):
def adjust_gamma(inpt: features.DType, gamma: float, gain: float = 1) -> features.DType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
elif isinstance(inpt, features._Feature):
return inpt.adjust_gamma(gamma=gamma, gain=gain)
elif isinstance(inpt, PIL.Image.Image):
return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
else:
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
posterize_image_tensor = _FT.posterize
posterize_image_pil = _FP.posterize
def posterize(inpt: DType, bits: int) -> DType:
if isinstance(inpt, features._Feature):
def posterize(inpt: features.DType, bits: int) -> features.DType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return posterize_image_tensor(inpt, bits=bits)
elif isinstance(inpt, features._Feature):
return inpt.posterize(bits=bits)
elif isinstance(inpt, PIL.Image.Image):
return posterize_image_pil(inpt, bits=bits)
else:
return posterize_image_tensor(inpt, bits=bits)
return posterize_image_pil(inpt, bits=bits)
solarize_image_tensor = _FT.solarize
solarize_image_pil = _FP.solarize
def solarize(inpt: DType, threshold: float) -> DType:
if isinstance(inpt, features._Feature):
def solarize(inpt: features.DType, threshold: float) -> features.DType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return solarize_image_tensor(inpt, threshold=threshold)
elif isinstance(inpt, features._Feature):
return inpt.solarize(threshold=threshold)
elif isinstance(inpt, PIL.Image.Image):
return solarize_image_pil(inpt, threshold=threshold)
else:
return solarize_image_tensor(inpt, threshold=threshold)
return solarize_image_pil(inpt, threshold=threshold)
autocontrast_image_tensor = _FT.autocontrast
autocontrast_image_pil = _FP.autocontrast
def autocontrast(inpt: DType) -> DType:
if isinstance(inpt, features._Feature):
def autocontrast(inpt: features.DType) -> features.DType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return autocontrast_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
return inpt.autocontrast()
elif isinstance(inpt, PIL.Image.Image):
return autocontrast_image_pil(inpt)
else:
return autocontrast_image_tensor(inpt)
return autocontrast_image_pil(inpt)
equalize_image_tensor = _FT.equalize
equalize_image_pil = _FP.equalize
def equalize(inpt: DType) -> DType:
if isinstance(inpt, features._Feature):
def equalize(inpt: features.DType) -> features.DType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return equalize_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
return inpt.equalize()
elif isinstance(inpt, PIL.Image.Image):
return equalize_image_pil(inpt)
else:
return equalize_image_tensor(inpt)
return equalize_image_pil(inpt)
invert_image_tensor = _FT.invert
invert_image_pil = _FP.invert
def invert(inpt: DType) -> DType:
if isinstance(inpt, features._Feature):
def invert(inpt: features.DType) -> features.DType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return invert_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
return inpt.invert()
elif isinstance(inpt, PIL.Image.Image):
return invert_image_pil(inpt)
else:
return invert_image_tensor(inpt)
return invert_image_pil(inpt)
import warnings
from typing import Any, List, Union
from typing import Any, List
import PIL.Image
import torch
@@ -8,6 +8,12 @@ from torchvision.prototype import features
from torchvision.transforms import functional as _F
# Due to torch.jit.script limitation we keep LegacyImageType as torch.Tensor
# instead of Union[torch.Tensor, PIL.Image.Image]
LegacyImageType = torch.Tensor
@torch.jit.unused
def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
call = ", num_output_channels=3" if num_output_channels == 3 else ""
replacement = "convert_color_space(..., color_space=features.ColorSpace.GRAY)"
@@ -21,10 +27,12 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
return _F.to_grayscale(inpt, num_output_channels=num_output_channels)
def rgb_to_grayscale(
inpt: Union[PIL.Image.Image, torch.Tensor], num_output_channels: int = 1
) -> Union[PIL.Image.Image, torch.Tensor]:
old_color_space = features.Image.guess_color_space(inpt) if features.is_simple_tensor(inpt) else None
def rgb_to_grayscale(inpt: LegacyImageType, num_output_channels: int = 1) -> LegacyImageType:
old_color_space = (
features._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type]
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image))
else None
)
call = ", num_output_channels=3" if num_output_channels == 3 else ""
replacement = (
@@ -44,6 +52,7 @@ def rgb_to_grayscale(
return _F.rgb_to_grayscale(inpt, num_output_channels=num_output_channels)
@torch.jit.unused
def to_tensor(inpt: Any) -> torch.Tensor:
warnings.warn(
"The function `to_tensor(...)` is deprecated and will be removed in a future release. "
@@ -52,7 +61,7 @@ def to_tensor(inpt: Any) -> torch.Tensor:
return _F.to_tensor(inpt)
def get_image_size(inpt: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> List[int]:
def get_image_size(inpt: features.ImageType) -> List[int]:
warnings.warn(
"The function `get_image_size(...)` is deprecated and will be removed in a future release. "
"Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`."
......
@@ -20,10 +20,6 @@ from torchvision.transforms.functional_tensor import _parse_pad_padding
from ._meta import convert_format_bounding_box, get_dimensions_image_pil, get_dimensions_image_tensor
# shortcut type
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
horizontal_flip_image_tensor = _FT.hflip
horizontal_flip_image_pil = _FP.hflip
@@ -48,13 +44,13 @@ def horizontal_flip_bounding_box(
).view(shape)
def horizontal_flip(inpt: DType) -> DType:
if isinstance(inpt, features._Feature):
def horizontal_flip(inpt: features.DType) -> features.DType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return horizontal_flip_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
return inpt.horizontal_flip()
elif isinstance(inpt, PIL.Image.Image):
return horizontal_flip_image_pil(inpt)
else:
return horizontal_flip_image_tensor(inpt)
return horizontal_flip_image_pil(inpt)
vertical_flip_image_tensor = _FT.vflip
@@ -81,13 +77,13 @@ def vertical_flip_bounding_box(
).view(shape)
def vertical_flip(inpt: DType) -> DType:
if isinstance(inpt, features._Feature):
def vertical_flip(inpt: features.DType) -> features.DType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return vertical_flip_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
return inpt.vertical_flip()
elif isinstance(inpt, PIL.Image.Image):
return vertical_flip_image_pil(inpt)
else:
return vertical_flip_image_tensor(inpt)
return vertical_flip_image_pil(inpt)
# We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
@@ -118,6 +114,7 @@ def resize_image_tensor(
return image.view(extra_dims + (num_channels, new_height, new_width))
@torch.jit.unused
def resize_image_pil(
img: PIL.Image.Image,
size: Union[Sequence[int], int],
@@ -157,22 +154,22 @@ def resize_bounding_box(
def resize(
inpt: DType,
inpt: features.DType,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[bool] = None,
) -> DType:
if isinstance(inpt, features._Feature):
) -> features.DType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
antialias = False if antialias is None else antialias
return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
elif isinstance(inpt, features._Feature):
antialias = False if antialias is None else antialias
return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)
elif isinstance(inpt, PIL.Image.Image):
else:
if antialias is not None and not antialias:
warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size)
else:
antialias = False if antialias is None else antialias
return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
def _affine_parse_args(
@@ -256,6 +253,7 @@ def affine_image_tensor(
return output.view(extra_dims + (num_channels, height, width))
@torch.jit.unused
def affine_image_pil(
img: PIL.Image.Image,
angle: float,
@@ -263,7 +261,7 @@ def affine_image_pil(
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
fill: Optional[Union[int, float, List[float]]] = None,
center: Optional[List[float]] = None,
) -> PIL.Image.Image:
angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
@@ -422,21 +420,17 @@ def _convert_fill_arg(
def affine(
inpt: DType,
inpt: features.DType,
angle: float,
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
fill: Optional[Union[int, float, List[float]]] = None,
center: Optional[List[float]] = None,
) -> DType:
if isinstance(inpt, features._Feature):
return inpt.affine(
angle, translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center
)
elif isinstance(inpt, PIL.Image.Image):
return affine_image_pil(
) -> features.DType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return affine_image_tensor(
inpt,
angle,
translate=translate,
@@ -446,10 +440,12 @@ def affine(
fill=fill,
center=center,
)
elif isinstance(inpt, features._Feature):
return inpt.affine(
angle, translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center
)
else:
fill = _convert_fill_arg(fill)
return affine_image_tensor(
return affine_image_pil(
inpt,
angle,
translate=translate,
@@ -499,12 +495,13 @@ def rotate_image_tensor(
return img.view(extra_dims + (num_channels, new_height, new_width))
@torch.jit.unused
def rotate_image_pil(
img: PIL.Image.Image,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
fill: Optional[Union[int, float, List[float]]] = None,
center: Optional[List[float]] = None,
) -> PIL.Image.Image:
if center is not None and expand:
@@ -567,21 +564,19 @@ def rotate_mask(
def rotate(
inpt: DType,
inpt: features.DType,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
fill: Optional[Union[int, float, List[float]]] = None,
center: Optional[List[float]] = None,
) -> DType:
if isinstance(inpt, features._Feature):
) -> features.DType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
elif isinstance(inpt, features._Feature):
return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
elif isinstance(inpt, PIL.Image.Image):
return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
else:
fill = _convert_fill_arg(fill)
return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
pad_image_pil = _FP.pad
@@ -700,23 +695,18 @@ def pad_bounding_box(
def pad(
inpt: DType,
padding: Union[int, Sequence[int]],
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
inpt: features.DType,
padding: Union[int, List[int]],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> DType:
if isinstance(inpt, features._Feature):
) -> features.DType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
elif isinstance(inpt, features._Feature):
return inpt.pad(padding, fill=fill, padding_mode=padding_mode)
elif isinstance(inpt, PIL.Image.Image):
return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
else:
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
if not isinstance(padding, int):
padding = list(padding)
fill = _convert_fill_arg(fill)
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
crop_image_tensor = _FT.crop
@@ -746,13 +736,13 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int)
return crop_image_tensor(mask, top, left, height, width)
def crop(inpt: DType, top: int, left: int, height: int, width: int) -> DType:
if isinstance(inpt, features._Feature):
def crop(inpt: features.DType, top: int, left: int, height: int, width: int) -> features.DType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return crop_image_tensor(inpt, top, left, height, width)
elif isinstance(inpt, features._Feature):
return inpt.crop(top, left, height, width)
elif isinstance(inpt, PIL.Image.Image):
return crop_image_pil(inpt, top, left, height, width)
else:
return crop_image_tensor(inpt, top, left, height, width)
return crop_image_pil(inpt, top, left, height, width)
def perspective_image_tensor(
@@ -764,6 +754,7 @@ def perspective_image_tensor(
return _FT.perspective(img, perspective_coeffs, interpolation=interpolation.value, fill=fill)
@torch.jit.unused
def perspective_image_pil(
img: PIL.Image.Image,
perspective_coeffs: List[float],
@@ -876,22 +867,20 @@ def perspective_mask(mask: torch.Tensor, perspective_coeffs: List[float]) -> tor
def perspective(
inpt: DType,
inpt: features.DType,
startpoints: List[List[int]],
endpoints: List[List[int]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> DType:
fill: Optional[Union[int, float, List[float]]] = None,
) -> features.DType:
perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints)
if isinstance(inpt, features._Feature):
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
elif isinstance(inpt, features._Feature):
return inpt.perspective(perspective_coeffs, interpolation=interpolation, fill=fill)
elif isinstance(inpt, PIL.Image.Image):
return perspective_image_pil(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
else:
fill = _convert_fill_arg(fill)
return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
return perspective_image_pil(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
def elastic_image_tensor(
......@@ -903,15 +892,14 @@ def elastic_image_tensor(
return _FT.elastic_transform(img, displacement, interpolation=interpolation.value, fill=fill)
@torch.jit.unused
def elastic_image_pil(
img: PIL.Image.Image,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
fill: Optional[Union[int, float, List[float]]] = None,
) -> PIL.Image.Image:
t_img = pil_to_tensor(img)
fill = _convert_fill_arg(fill)
output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill)
return to_pil_image(output, mode=img.mode)
@@ -972,19 +960,17 @@ def elastic_mask(mask: torch.Tensor, displacement: torch.Tensor) -> torch.Tensor
def elastic(
inpt: DType,
inpt: features.DType,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> DType:
if isinstance(inpt, features._Feature):
fill: Optional[Union[int, float, List[float]]] = None,
) -> features.DType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
elif isinstance(inpt, features._Feature):
return inpt.elastic(displacement, interpolation=interpolation, fill=fill)
elif isinstance(inpt, PIL.Image.Image):
return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill)
else:
fill = _convert_fill_arg(fill)
return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill)
elastic_transform = elastic
@@ -1032,6 +1018,7 @@ def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch
return crop_image_tensor(img, crop_top, crop_left, crop_height, crop_width)
@torch.jit.unused
def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
_, image_height, image_width = get_dimensions_image_pil(img)
@@ -1074,13 +1061,13 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor
return output
def center_crop(inpt: DType, output_size: List[int]) -> DType:
if isinstance(inpt, features._Feature):
def center_crop(inpt: features.DType, output_size: List[int]) -> features.DType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return center_crop_image_tensor(inpt, output_size)
elif isinstance(inpt, features._Feature):
return inpt.center_crop(output_size)
elif isinstance(inpt, PIL.Image.Image):
return center_crop_image_pil(inpt, output_size)
else:
return center_crop_image_tensor(inpt, output_size)
return center_crop_image_pil(inpt, output_size)
def resized_crop_image_tensor(
@@ -1097,6 +1084,7 @@ def resized_crop_image_tensor(
return resize_image_tensor(img, size, interpolation=interpolation, antialias=antialias)
@torch.jit.unused
def resized_crop_image_pil(
img: PIL.Image.Image,
top: int,
@@ -1136,7 +1124,7 @@ def resized_crop_mask(
def resized_crop(
inpt: DType,
inpt: features.DType,
top: int,
left: int,
height: int,
@@ -1144,17 +1132,17 @@ def resized_crop(
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
) -> DType:
if isinstance(inpt, features._Feature):
antialias = False if antialias is None else antialias
return inpt.resized_crop(top, left, height, width, antialias=antialias, size=size, interpolation=interpolation)
elif isinstance(inpt, PIL.Image.Image):
return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation)
else:
) -> features.DType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
antialias = False if antialias is None else antialias
return resized_crop_image_tensor(
inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
)
elif isinstance(inpt, features._Feature):
antialias = False if antialias is None else antialias
return inpt.resized_crop(top, left, height, width, antialias=antialias, size=size, interpolation=interpolation)
else:
return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation)
def _parse_five_crop_size(size: List[int]) -> List[int]:
@@ -1188,6 +1176,7 @@ def five_crop_image_tensor(
return tl, tr, bl, br, center
@torch.jit.unused
def five_crop_image_pil(
img: PIL.Image.Image, size: List[int]
) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
@@ -1207,11 +1196,13 @@ def five_crop_image_pil(
return tl, tr, bl, br, center
def five_crop(inpt: DType, size: List[int]) -> Tuple[DType, DType, DType, DType, DType]:
# TODO: consider breaking BC here to return List[DType] to align this op with `ten_crop`
def five_crop(
inpt: features.ImageType, size: List[int]
) -> Tuple[features.ImageType, features.ImageType, features.ImageType, features.ImageType, features.ImageType]:
# TODO: consider breaking BC here to return List[features.ImageType] to align this op with `ten_crop`
if isinstance(inpt, torch.Tensor):
output = five_crop_image_tensor(inpt, size)
if isinstance(inpt, features.Image):
if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
output = tuple(features.Image.new_like(inpt, item) for item in output) # type: ignore[assignment]
return output
else: # isinstance(inpt, PIL.Image.Image):
@@ -1231,6 +1222,7 @@ def ten_crop_image_tensor(img: torch.Tensor, size: List[int], vertical_flip: boo
return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
@torch.jit.unused
def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: bool = False) -> List[PIL.Image.Image]:
tl, tr, bl, br, center = five_crop_image_pil(img, size)
@@ -1244,10 +1236,10 @@ def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: boo
return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
def ten_crop(inpt: DType, size: List[int], vertical_flip: bool = False) -> List[DType]:
def ten_crop(inpt: features.ImageType, size: List[int], vertical_flip: bool = False) -> List[features.ImageType]:
if isinstance(inpt, torch.Tensor):
output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
if isinstance(inpt, features.Image):
if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
output = [features.Image.new_like(inpt, item) for item in output]
return output
else: # isinstance(inpt, PIL.Image.Image):
......
from typing import Any, List, Optional, Tuple, Union
from typing import cast, List, Optional, Tuple
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.features import BoundingBoxFormat, ColorSpace, Image
from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
get_dimensions_image_tensor = _FT.get_dimensions
get_dimensions_image_pil = _FP.get_dimensions
def get_chw(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
if isinstance(image, features.Image):
# TODO: Should this be prefixed with `_`, like other helpers that are not exposed by the package `__init__`?
def get_chw(image: features.ImageType) -> Tuple[int, int, int]:
if isinstance(image, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(image, features.Image)):
channels, height, width = get_dimensions_image_tensor(image)
elif isinstance(image, features.Image):
channels = image.num_channels
height, width = image.image_size
elif features.is_simple_tensor(image):
channels, height, width = get_dimensions_image_tensor(image)
elif isinstance(image, PIL.Image.Image):
else: # isinstance(image, PIL.Image.Image)
channels, height, width = get_dimensions_image_pil(image)
else:
raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}")
return channels, height, width
@@ -30,11 +29,11 @@ def get_chw(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tupl
# detailed above.
def get_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> List[int]:
def get_dimensions(image: features.ImageType) -> List[int]:
return list(get_chw(image))
def get_num_channels(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> int:
def get_num_channels(image: features.ImageType) -> int:
num_channels, *_ = get_chw(image)
return num_channels
@@ -44,7 +43,7 @@ def get_num_channels(image: Union[PIL.Image.Image, torch.Tensor, features.Image]
get_image_num_channels = get_num_channels
def get_spatial_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> List[int]:
def get_spatial_size(image: features.ImageType) -> List[int]:
_, *size = get_chw(image)
return size
@@ -192,6 +191,7 @@ _COLOR_SPACE_TO_PIL_MODE = {
}
@torch.jit.unused
def convert_color_space_image_pil(
image: PIL.Image.Image, color_space: ColorSpace, copy: bool = True
) -> PIL.Image.Image:
@@ -208,17 +208,12 @@ def convert_color_space_image_pil(
def convert_color_space(
inpt: Union[PIL.Image.Image, torch.Tensor, features._Feature],
*,
inpt: features.ImageType,
color_space: ColorSpace,
old_color_space: Optional[ColorSpace] = None,
copy: bool = True,
) -> Any:
if isinstance(inpt, Image):
return inpt.to_color_space(color_space, copy=copy)
elif isinstance(inpt, PIL.Image.Image):
return convert_color_space_image_pil(inpt, color_space, copy=copy)
else:
) -> features.ImageType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image)):
if old_color_space is None:
raise RuntimeError(
"In order to convert the color space of simple tensor images, "
@@ -227,3 +222,7 @@ def convert_color_space(
return convert_color_space_image_tensor(
inpt, old_color_space=old_color_space, new_color_space=color_space, copy=copy
)
elif isinstance(inpt, features.Image):
return inpt.to_color_space(color_space, copy=copy)
else:
return cast(features.ImageType, convert_color_space_image_pil(inpt, color_space, copy=copy))
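
A usage sketch of the branch order above (module path assumed as elsewhere in the prototype): plain tensors carry no color-space metadata, so `old_color_space` is mandatory for them, while `features.Image` inputs know their own space:

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F

    rgb = torch.rand(3, 8, 8)
    gray = F.convert_color_space(
        rgb,
        color_space=features.ColorSpace.GRAY,
        old_color_space=features.ColorSpace.RGB,  # required for a plain tensor
    )

    img = features.Image(rgb, color_space=features.ColorSpace.RGB)
    gray_img = F.convert_color_space(img, color_space=features.ColorSpace.GRAY)
    # Omitting old_color_space for a plain tensor raises the RuntimeError above.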
from typing import List, Optional, Union
from typing import List, Optional
import PIL.Image
import torch
@@ -7,16 +7,15 @@ from torchvision.transforms import functional_tensor as _FT
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
# shortcut type
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
# Due to torch.jit.script limitation we keep TensorImageType as torch.Tensor
# instead of Union[torch.Tensor, features.Image]
TensorImageType = torch.Tensor
normalize_image_tensor = _FT.normalize
def normalize(
inpt: Union[torch.Tensor, features.Image], mean: List[float], std: List[float], inplace: bool = False
) -> torch.Tensor:
def normalize(inpt: TensorImageType, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
if not isinstance(inpt, torch.Tensor):
raise TypeError(f"img should be Tensor Image. Got {type(inpt)}")
else:
@@ -54,6 +53,7 @@ def gaussian_blur_image_tensor(
return _FT.gaussian_blur(img, kernel_size, sigma)
@torch.jit.unused
def gaussian_blur_image_pil(
img: PIL.Image.Image, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> PIL.Image.Image:
@@ -62,10 +62,10 @@ def gaussian_blur_image_pil(
return to_pil_image(output, mode=img.mode)
def gaussian_blur(inpt: DType, kernel_size: List[int], sigma: Optional[List[float]] = None) -> DType:
if isinstance(inpt, features._Feature):
def gaussian_blur(inpt: features.DType, kernel_size: List[int], sigma: Optional[List[float]] = None) -> features.DType:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
elif isinstance(inpt, features._Feature):
return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma)
elif isinstance(inpt, PIL.Image.Image):
return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma)
else:
return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma)
@@ -10,6 +10,7 @@ from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer
from torchvision.transforms import functional as _F
@torch.jit.unused
def decode_image_with_pil(encoded_image: torch.Tensor) -> features.Image:
image = torch.as_tensor(np.array(PIL.Image.open(ReadOnlyTensorBuffer(encoded_image)), copy=True))
if image.ndim == 2:
@@ -17,11 +18,13 @@ def decode_image_with_pil(encoded_image: torch.Tensor) -> features.Image:
return features.Image(image.permute(2, 0, 1))
@torch.jit.unused
def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
with unittest.mock.patch("torchvision.io.video.os.path.exists", return_value=True):
return read_video(ReadOnlyTensorBuffer(encoded_video)) # type: ignore[arg-type]
@torch.jit.unused
def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> features.Image:
if isinstance(image, np.ndarray):
output = torch.from_numpy(image)
......