"...transforms/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "d4d20f01e191dbacd0a0e6c8a5db5062222753ba"
Unverified Commit f71c4308 authored by Philip Meier, committed by GitHub
Browse files

simplify dispatcher if-elif (#7084)

parent 69ae61a1
...@@ -32,7 +32,6 @@ no_implicit_optional = True ...@@ -32,7 +32,6 @@ no_implicit_optional = True
; warnings ; warnings
warn_unused_ignores = True warn_unused_ignores = True
warn_return_any = True
; miscellaneous strictness flags ; miscellaneous strictness flags
allow_redefinition = True allow_redefinition = True
......
...@@ -46,7 +46,7 @@ class ToImageTensor(Transform): ...@@ -46,7 +46,7 @@ class ToImageTensor(Transform):
def _transform( def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> datapoints.Image: ) -> datapoints.Image:
return F.to_image_tensor(inpt) # type: ignore[no-any-return] return F.to_image_tensor(inpt)
class ToImagePIL(Transform): class ToImagePIL(Transform):
......
# TODO: Add _log_api_usage_once() in all mid-level kernels. If they remain not jit-scriptable we can use decorators # TODO: Add _log_api_usage_once() in all mid-level kernels. If they remain not jit-scriptable we can use decorators
from torchvision.transforms import InterpolationMode # usort: skip from torchvision.transforms import InterpolationMode # usort: skip
from ._utils import is_simple_tensor # usort: skip
from ._meta import ( from ._meta import (
clamp_bounding_box, clamp_bounding_box,
convert_format_bounding_box, convert_format_bounding_box,
......
...@@ -7,6 +7,8 @@ from torchvision.prototype import datapoints ...@@ -7,6 +7,8 @@ from torchvision.prototype import datapoints
from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once from torchvision.utils import _log_api_usage_once
from ._utils import is_simple_tensor
def erase_image_tensor( def erase_image_tensor(
image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
...@@ -45,9 +47,7 @@ def erase( ...@@ -45,9 +47,7 @@ def erase(
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(erase) _log_api_usage_once(erase)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
):
return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
elif isinstance(inpt, datapoints.Image): elif isinstance(inpt, datapoints.Image):
output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace) output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
......
...@@ -8,6 +8,7 @@ from torchvision.transforms.functional_tensor import _max_value ...@@ -8,6 +8,7 @@ from torchvision.transforms.functional_tensor import _max_value
from torchvision.utils import _log_api_usage_once from torchvision.utils import _log_api_usage_once
from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor
from ._utils import is_simple_tensor
def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor: def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
...@@ -43,9 +44,7 @@ def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) - ...@@ -43,9 +44,7 @@ def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) -
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(adjust_brightness) _log_api_usage_once(adjust_brightness)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_brightness(brightness_factor=brightness_factor) return inpt.adjust_brightness(brightness_factor=brightness_factor)
...@@ -131,9 +130,7 @@ def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> da ...@@ -131,9 +130,7 @@ def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> da
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(adjust_contrast) _log_api_usage_once(adjust_contrast)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_contrast(contrast_factor=contrast_factor) return inpt.adjust_contrast(contrast_factor=contrast_factor)
...@@ -326,9 +323,7 @@ def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.I ...@@ -326,9 +323,7 @@ def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.I
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(adjust_hue) _log_api_usage_once(adjust_hue)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_hue(hue_factor=hue_factor) return inpt.adjust_hue(hue_factor=hue_factor)
...@@ -371,9 +366,7 @@ def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) - ...@@ -371,9 +366,7 @@ def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) -
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(adjust_gamma) _log_api_usage_once(adjust_gamma)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.adjust_gamma(gamma=gamma, gain=gain) return inpt.adjust_gamma(gamma=gamma, gain=gain)
...@@ -410,9 +403,7 @@ def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJ ...@@ -410,9 +403,7 @@ def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJ
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(posterize) _log_api_usage_once(posterize)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return posterize_image_tensor(inpt, bits=bits) return posterize_image_tensor(inpt, bits=bits)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.posterize(bits=bits) return inpt.posterize(bits=bits)
...@@ -443,9 +434,7 @@ def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.Inpu ...@@ -443,9 +434,7 @@ def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.Inpu
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(solarize) _log_api_usage_once(solarize)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return solarize_image_tensor(inpt, threshold=threshold) return solarize_image_tensor(inpt, threshold=threshold)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.solarize(threshold=threshold) return inpt.solarize(threshold=threshold)
...@@ -498,9 +487,7 @@ def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: ...@@ -498,9 +487,7 @@ def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(autocontrast) _log_api_usage_once(autocontrast)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return autocontrast_image_tensor(inpt) return autocontrast_image_tensor(inpt)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.autocontrast() return inpt.autocontrast()
...@@ -593,9 +580,7 @@ def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: ...@@ -593,9 +580,7 @@ def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(equalize) _log_api_usage_once(equalize)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return equalize_image_tensor(inpt) return equalize_image_tensor(inpt)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.equalize() return inpt.equalize()
...@@ -610,7 +595,7 @@ def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: ...@@ -610,7 +595,7 @@ def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
def invert_image_tensor(image: torch.Tensor) -> torch.Tensor: def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
if image.is_floating_point(): if image.is_floating_point():
return 1.0 - image # type: ignore[no-any-return] return 1.0 - image
elif image.dtype == torch.uint8: elif image.dtype == torch.uint8:
return image.bitwise_not() return image.bitwise_not()
else: # signed integer dtypes else: # signed integer dtypes
...@@ -629,9 +614,7 @@ def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: ...@@ -629,9 +614,7 @@ def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(invert) _log_api_usage_once(invert)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return invert_image_tensor(inpt) return invert_image_tensor(inpt)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.invert() return inpt.invert()
......
...@@ -7,6 +7,8 @@ import torch ...@@ -7,6 +7,8 @@ import torch
from torchvision.prototype import datapoints from torchvision.prototype import datapoints
from torchvision.transforms import functional as _F from torchvision.transforms import functional as _F
from ._utils import is_simple_tensor
@torch.jit.unused @torch.jit.unused
def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image: def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
...@@ -25,14 +27,14 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima ...@@ -25,14 +27,14 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
def rgb_to_grayscale( def rgb_to_grayscale(
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1 inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1
) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]: ) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
if not torch.jit.is_scripting() and isinstance(inpt, (datapoints.Image, datapoints.Video)): if torch.jit.is_scripting() or is_simple_tensor(inpt):
inpt = inpt.as_subclass(torch.Tensor)
old_color_space = None
elif isinstance(inpt, torch.Tensor):
old_color_space = datapoints._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type] old_color_space = datapoints._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type]
else: else:
old_color_space = None old_color_space = None
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
inpt = inpt.as_subclass(torch.Tensor)
call = ", num_output_channels=3" if num_output_channels == 3 else "" call = ", num_output_channels=3" if num_output_channels == 3 else ""
replacement = ( replacement = (
f"convert_color_space(..., color_space=datapoints.ColorSpace.GRAY" f"convert_color_space(..., color_space=datapoints.ColorSpace.GRAY"
......
...@@ -23,6 +23,8 @@ from torchvision.utils import _log_api_usage_once ...@@ -23,6 +23,8 @@ from torchvision.utils import _log_api_usage_once
from ._meta import convert_format_bounding_box, get_spatial_size_image_pil from ._meta import convert_format_bounding_box, get_spatial_size_image_pil
from ._utils import is_simple_tensor
def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor: def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
return image.flip(-1) return image.flip(-1)
...@@ -60,9 +62,7 @@ def horizontal_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: ...@@ -60,9 +62,7 @@ def horizontal_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(horizontal_flip) _log_api_usage_once(horizontal_flip)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return horizontal_flip_image_tensor(inpt) return horizontal_flip_image_tensor(inpt)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.horizontal_flip() return inpt.horizontal_flip()
...@@ -111,9 +111,7 @@ def vertical_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: ...@@ -111,9 +111,7 @@ def vertical_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(vertical_flip) _log_api_usage_once(vertical_flip)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return vertical_flip_image_tensor(inpt) return vertical_flip_image_tensor(inpt)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.vertical_flip() return inpt.vertical_flip()
...@@ -241,9 +239,7 @@ def resize( ...@@ -241,9 +239,7 @@ def resize(
) -> datapoints.InputTypeJIT: ) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(resize) _log_api_usage_once(resize)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias) return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)
...@@ -744,9 +740,7 @@ def affine( ...@@ -744,9 +740,7 @@ def affine(
_log_api_usage_once(affine) _log_api_usage_once(affine)
# TODO: consider deprecating integers from angle and shear on the future # TODO: consider deprecating integers from angle and shear on the future
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return affine_image_tensor( return affine_image_tensor(
inpt, inpt,
angle, angle,
...@@ -929,9 +923,7 @@ def rotate( ...@@ -929,9 +923,7 @@ def rotate(
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(rotate) _log_api_usage_once(rotate)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center) return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
...@@ -1139,9 +1131,7 @@ def pad( ...@@ -1139,9 +1131,7 @@ def pad(
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(pad) _log_api_usage_once(pad)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints._datapoint.Datapoint):
...@@ -1219,9 +1209,7 @@ def crop(inpt: datapoints.InputTypeJIT, top: int, left: int, height: int, width: ...@@ -1219,9 +1209,7 @@ def crop(inpt: datapoints.InputTypeJIT, top: int, left: int, height: int, width:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(crop) _log_api_usage_once(crop)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return crop_image_tensor(inpt, top, left, height, width) return crop_image_tensor(inpt, top, left, height, width)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.crop(top, left, height, width) return inpt.crop(top, left, height, width)
...@@ -1476,9 +1464,7 @@ def perspective( ...@@ -1476,9 +1464,7 @@ def perspective(
) -> datapoints.InputTypeJIT: ) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(perspective) _log_api_usage_once(perspective)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return perspective_image_tensor( return perspective_image_tensor(
inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
) )
...@@ -1639,9 +1625,7 @@ def elastic( ...@@ -1639,9 +1625,7 @@ def elastic(
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(elastic) _log_api_usage_once(elastic)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill) return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.elastic(displacement, interpolation=interpolation, fill=fill) return inpt.elastic(displacement, interpolation=interpolation, fill=fill)
...@@ -1754,9 +1738,7 @@ def center_crop(inpt: datapoints.InputTypeJIT, output_size: List[int]) -> datapo ...@@ -1754,9 +1738,7 @@ def center_crop(inpt: datapoints.InputTypeJIT, output_size: List[int]) -> datapo
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(center_crop) _log_api_usage_once(center_crop)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return center_crop_image_tensor(inpt, output_size) return center_crop_image_tensor(inpt, output_size)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.center_crop(output_size) return inpt.center_crop(output_size)
...@@ -1850,9 +1832,7 @@ def resized_crop( ...@@ -1850,9 +1832,7 @@ def resized_crop(
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(resized_crop) _log_api_usage_once(resized_crop)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return resized_crop_image_tensor( return resized_crop_image_tensor(
inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
) )
...@@ -1935,9 +1915,7 @@ def five_crop( ...@@ -1935,9 +1915,7 @@ def five_crop(
# TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with # TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with
# `ten_crop` # `ten_crop`
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
):
return five_crop_image_tensor(inpt, size) return five_crop_image_tensor(inpt, size)
elif isinstance(inpt, datapoints.Image): elif isinstance(inpt, datapoints.Image):
output = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size) output = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size)
...@@ -1991,9 +1969,7 @@ def ten_crop( ...@@ -1991,9 +1969,7 @@ def ten_crop(
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(ten_crop) _log_api_usage_once(ten_crop)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
):
return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip) return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
elif isinstance(inpt, datapoints.Image): elif isinstance(inpt, datapoints.Image):
output = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip) output = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
......
...@@ -9,6 +9,8 @@ from torchvision.transforms.functional_tensor import _max_value ...@@ -9,6 +9,8 @@ from torchvision.transforms.functional_tensor import _max_value
from torchvision.utils import _log_api_usage_once from torchvision.utils import _log_api_usage_once
from ._utils import is_simple_tensor
def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]: def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
chw = list(image.shape[-3:]) chw = list(image.shape[-3:])
...@@ -29,9 +31,7 @@ def get_dimensions(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT] ...@@ -29,9 +31,7 @@ def get_dimensions(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(get_dimensions) _log_api_usage_once(get_dimensions)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
):
return get_dimensions_image_tensor(inpt) return get_dimensions_image_tensor(inpt)
elif isinstance(inpt, (datapoints.Image, datapoints.Video)): elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
channels = inpt.num_channels channels = inpt.num_channels
...@@ -68,9 +68,7 @@ def get_num_channels(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJI ...@@ -68,9 +68,7 @@ def get_num_channels(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJI
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(get_num_channels) _log_api_usage_once(get_num_channels)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
):
return get_num_channels_image_tensor(inpt) return get_num_channels_image_tensor(inpt)
elif isinstance(inpt, (datapoints.Image, datapoints.Video)): elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
return inpt.num_channels return inpt.num_channels
...@@ -120,14 +118,12 @@ def get_spatial_size(inpt: datapoints.InputTypeJIT) -> List[int]: ...@@ -120,14 +118,12 @@ def get_spatial_size(inpt: datapoints.InputTypeJIT) -> List[int]:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(get_spatial_size) _log_api_usage_once(get_spatial_size)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return get_spatial_size_image_tensor(inpt) return get_spatial_size_image_tensor(inpt)
elif isinstance(inpt, (datapoints.Image, datapoints.Video, datapoints.BoundingBox, datapoints.Mask)): elif isinstance(inpt, (datapoints.Image, datapoints.Video, datapoints.BoundingBox, datapoints.Mask)):
return list(inpt.spatial_size) return list(inpt.spatial_size)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return get_spatial_size_image_pil(inpt) # type: ignore[no-any-return] return get_spatial_size_image_pil(inpt)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...@@ -143,7 +139,7 @@ def get_num_frames(inpt: datapoints.VideoTypeJIT) -> int: ...@@ -143,7 +139,7 @@ def get_num_frames(inpt: datapoints.VideoTypeJIT) -> int:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(get_num_frames) _log_api_usage_once(get_num_frames)
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Video)): if torch.jit.is_scripting() or is_simple_tensor(inpt):
return get_num_frames_video(inpt) return get_num_frames_video(inpt)
elif isinstance(inpt, datapoints.Video): elif isinstance(inpt, datapoints.Video):
return inpt.num_frames return inpt.num_frames
...@@ -336,9 +332,7 @@ def convert_color_space( ...@@ -336,9 +332,7 @@ def convert_color_space(
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(convert_color_space) _log_api_usage_once(convert_color_space)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
):
if old_color_space is None: if old_color_space is None:
raise RuntimeError( raise RuntimeError(
"In order to convert the color space of simple tensors, " "In order to convert the color space of simple tensors, "
...@@ -443,9 +437,7 @@ def convert_dtype( ...@@ -443,9 +437,7 @@ def convert_dtype(
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(convert_dtype) _log_api_usage_once(convert_dtype)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
):
return convert_dtype_image_tensor(inpt, dtype) return convert_dtype_image_tensor(inpt, dtype)
elif isinstance(inpt, datapoints.Image): elif isinstance(inpt, datapoints.Image):
output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype) output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype)
......
...@@ -10,7 +10,7 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image ...@@ -10,7 +10,7 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once from torchvision.utils import _log_api_usage_once
from ..utils import is_simple_tensor from ._utils import is_simple_tensor
def normalize_image_tensor( def normalize_image_tensor(
...@@ -61,9 +61,9 @@ def normalize( ...@@ -61,9 +61,9 @@ def normalize(
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(normalize) _log_api_usage_once(normalize)
if is_simple_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video)): if isinstance(inpt, (datapoints.Image, datapoints.Video)):
inpt = inpt.as_subclass(torch.Tensor) inpt = inpt.as_subclass(torch.Tensor)
else: elif not is_simple_tensor(inpt):
raise TypeError( raise TypeError(
f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"Input can either be a plain tensor or an `Image` or `Video` datapoint, "
f"but got {type(inpt)} instead." f"but got {type(inpt)} instead."
...@@ -175,9 +175,7 @@ def gaussian_blur( ...@@ -175,9 +175,7 @@ def gaussian_blur(
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(gaussian_blur) _log_api_usage_once(gaussian_blur)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
):
return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma) return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma) return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma)
......
...@@ -4,6 +4,8 @@ from torchvision.prototype import datapoints ...@@ -4,6 +4,8 @@ from torchvision.prototype import datapoints
from torchvision.utils import _log_api_usage_once from torchvision.utils import _log_api_usage_once
from ._utils import is_simple_tensor
def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temporal_dim: int = -4) -> torch.Tensor: def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temporal_dim: int = -4) -> torch.Tensor:
# Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19 # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
...@@ -18,7 +20,7 @@ def uniform_temporal_subsample( ...@@ -18,7 +20,7 @@ def uniform_temporal_subsample(
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(uniform_temporal_subsample) _log_api_usage_once(uniform_temporal_subsample)
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Video)): if torch.jit.is_scripting() or is_simple_tensor(inpt):
return uniform_temporal_subsample_video(inpt, num_samples, temporal_dim=temporal_dim) return uniform_temporal_subsample_video(inpt, num_samples, temporal_dim=temporal_dim)
elif isinstance(inpt, datapoints.Video): elif isinstance(inpt, datapoints.Video):
if temporal_dim != -4 and inpt.ndim - 4 != temporal_dim: if temporal_dim != -4 and inpt.ndim - 4 != temporal_dim:
......
from typing import Any
import torch
from torchvision.prototype.datapoints._datapoint import Datapoint
def is_simple_tensor(inpt: Any) -> bool:
    """Return ``True`` iff *inpt* is a plain :class:`torch.Tensor`.

    A "simple" tensor is one that is not a :class:`Datapoint` subclass
    (e.g. not an ``Image``, ``Video``, ``BoundingBox``, or ``Mask``).
    """
    if not isinstance(inpt, torch.Tensor):
        return False
    return not isinstance(inpt, Datapoint)
...@@ -3,16 +3,10 @@ from __future__ import annotations ...@@ -3,16 +3,10 @@ from __future__ import annotations
from typing import Any, Callable, List, Tuple, Type, Union from typing import Any, Callable, List, Tuple, Type, Union
import PIL.Image import PIL.Image
import torch
from torchvision._utils import sequence_to_str from torchvision._utils import sequence_to_str
from torchvision.prototype import datapoints from torchvision.prototype import datapoints
from torchvision.prototype.datapoints._datapoint import Datapoint from torchvision.prototype.transforms.functional import get_dimensions, get_spatial_size, is_simple_tensor
from torchvision.prototype.transforms.functional import get_dimensions, get_spatial_size
def is_simple_tensor(inpt: Any) -> bool:
    """Return ``True`` iff *inpt* is a plain tensor rather than a Datapoint subclass."""
    is_plain = isinstance(inpt, torch.Tensor) and not isinstance(inpt, Datapoint)
    return is_plain
def query_bounding_box(flat_inputs: List[Any]) -> datapoints.BoundingBox: def query_bounding_box(flat_inputs: List[Any]) -> datapoints.BoundingBox:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment