Unverified commit 641fdd9f authored by Philip Meier, committed by GitHub

remove custom type definitions from datapoints module (#7814)

parent 6b020798
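The datapoints module previously exposed a set of private type aliases (`_FillType`, `_FillTypeJIT`, `_InputType`, `_InputTypeJIT`, `_ImageType`, `_ImageTypeJIT`, `_TensorImageType`, `_TensorImageTypeJIT`, `_VideoType`, `_VideoTypeJIT`, `_TensorVideoType`, `_TensorVideoTypeJIT`). This commit removes them: the fill aliases are now imported from `torchvision.transforms.v2.functional._utils` (and through `torchvision.transforms.v2._utils` where needed), while the remaining annotations are spelled out inline as plain `torch.Tensor`, `Union[...]`, or `Any`. A minimal before/after sketch of what the annotation change looks like for a kernel (the kernel name here is hypothetical, not part of the diff):

# Before: kernels were annotated with private aliases from the datapoints module, e.g.
#   def my_kernel(inpt: datapoints._InputTypeJIT, fill: datapoints._FillTypeJIT = None) -> datapoints._InputTypeJIT: ...

# After: plain tensor annotations; the fill aliases now live in a private utils module.
import torch
from torchvision.transforms.v2.functional._utils import _FillTypeJIT

def my_kernel(inpt: torch.Tensor, fill: _FillTypeJIT = None) -> torch.Tensor:
    # `my_kernel` is a hypothetical stand-in for the kernels touched below.
    return inpt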
from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS

from ._bounding_box import BoundingBoxes, BoundingBoxFormat
-from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT, Datapoint
+from ._datapoint import Datapoint
-from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image
+from ._image import Image
from ._mask import Mask
-from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video
+from ._video import Video

if _WARN_ABOUT_BETA_TRANSFORMS:
    import warnings
...
from __future__ import annotations

-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union
+from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union

-import PIL.Image
import torch
from torch._C import DisableTorchFunctionSubclass
from torch.types import _device, _dtype, _size

D = TypeVar("D", bound="Datapoint")

-_FillType = Union[int, float, Sequence[int], Sequence[float], None]
-_FillTypeJIT = Optional[List[float]]

class Datapoint(torch.Tensor):
@@ -132,7 +129,3 @@ class Datapoint(torch.Tensor):
        # `BoundingBoxes.format` and `BoundingBoxes.canvas_size`, which are immutable and thus implicitly deep-copied by
        # `BoundingBoxes.clone()`.
        return self.detach().clone().requires_grad_(self.requires_grad)  # type: ignore[return-value]

-_InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint]
-_InputTypeJIT = torch.Tensor
@@ -45,9 +45,3 @@ class Image(Datapoint):
    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
        return self._make_repr()

-_ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
-_ImageTypeJIT = torch.Tensor
-_TensorImageType = Union[torch.Tensor, Image]
-_TensorImageTypeJIT = torch.Tensor
@@ -35,9 +35,3 @@ class Video(Datapoint):
    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
        return self._make_repr()

-_VideoType = Union[torch.Tensor, Video]
-_VideoTypeJIT = torch.Tensor
-_TensorVideoType = Union[torch.Tensor, Video]
-_TensorVideoTypeJIT = torch.Tensor
@@ -26,15 +26,15 @@ class SimpleCopyPaste(Transform):
    def _copy_paste(
        self,
-        image: datapoints._TensorImageType,
+        image: Union[torch.Tensor, datapoints.Image],
        target: Dict[str, Any],
-        paste_image: datapoints._TensorImageType,
+        paste_image: Union[torch.Tensor, datapoints.Image],
        paste_target: Dict[str, Any],
        random_selection: torch.Tensor,
        blending: bool,
        resize_interpolation: F.InterpolationMode,
        antialias: Optional[bool],
-    ) -> Tuple[datapoints._TensorImageType, Dict[str, Any]]:
+    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection])
        paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection])
@@ -106,7 +106,7 @@ class SimpleCopyPaste(Transform):
    def _extract_image_targets(
        self, flat_sample: List[Any]
-    ) -> Tuple[List[datapoints._TensorImageType], List[Dict[str, Any]]]:
+    ) -> Tuple[List[Union[torch.Tensor, datapoints.Image]], List[Dict[str, Any]]]:
        # fetch all images, bboxes, masks and labels from unstructured input
        # with List[image], List[BoundingBoxes], List[Mask], List[Label]
        images, bboxes, masks, labels = [], [], [], []
@@ -137,7 +137,7 @@ class SimpleCopyPaste(Transform):
    def _insert_outputs(
        self,
        flat_sample: List[Any],
-        output_images: List[datapoints._TensorImageType],
+        output_images: List[torch.Tensor],
        output_targets: List[Dict[str, Any]],
    ) -> None:
        c0, c1, c2, c3 = 0, 0, 0, 0
...
@@ -6,7 +6,7 @@ import torch
from torchvision import datapoints
from torchvision.prototype.datapoints import Label, OneHotLabel
from torchvision.transforms.v2 import functional as F, Transform
-from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg, _setup_size
+from torchvision.transforms.v2._utils import _FillType, _get_fill, _setup_fill_arg, _setup_size
from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_simple_tensor, query_size
@@ -14,7 +14,7 @@ class FixedSizeCrop(Transform):
    def __init__(
        self,
        size: Union[int, Sequence[int]],
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
        padding_mode: str = "constant",
    ) -> None:
        super().__init__()
...
@@ -39,9 +39,7 @@ class PermuteDimensions(Transform):
            )
        self.dims = dims

-    def _transform(
-        self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
-    ) -> torch.Tensor:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
        dims = self.dims[type(inpt)]
        if dims is None:
            return inpt.as_subclass(torch.Tensor)
@@ -63,9 +61,7 @@ class TransposeDimensions(Transform):
            )
        self.dims = dims

-    def _transform(
-        self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
-    ) -> torch.Tensor:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
        dims = self.dims[type(inpt)]
        if dims is None:
            return inpt.as_subclass(torch.Tensor)
...
@@ -10,17 +10,21 @@ from torchvision.transforms import _functional_tensor as _FT
from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.transforms.v2.functional._geometry import _check_interpolation
from torchvision.transforms.v2.functional._meta import get_size
+from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT

from ._utils import _get_fill, _setup_fill_arg
from .utils import check_type, is_simple_tensor

+ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video]

class _AutoAugmentBase(Transform):
    def __init__(
        self,
        *,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__()
        self.interpolation = _check_interpolation(interpolation)
@@ -35,7 +39,7 @@ class _AutoAugmentBase(Transform):
        self,
        inputs: Any,
        unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBoxes, datapoints.Mask),
-    ) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints._ImageType, datapoints._VideoType]]:
+    ) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]:
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
        needs_transform_list = self._needs_transform_list(flat_inputs)
@@ -68,7 +72,7 @@ class _AutoAugmentBase(Transform):
    def _unflatten_and_insert_image_or_video(
        self,
        flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
-        image_or_video: Union[datapoints._ImageType, datapoints._VideoType],
+        image_or_video: ImageOrVideo,
    ) -> Any:
        flat_inputs, spec, idx = flat_inputs_with_spec
        flat_inputs[idx] = image_or_video
@@ -76,12 +80,12 @@ class _AutoAugmentBase(Transform):
    def _apply_image_or_video_transform(
        self,
-        image: Union[datapoints._ImageType, datapoints._VideoType],
+        image: ImageOrVideo,
        transform_id: str,
        magnitude: float,
        interpolation: Union[InterpolationMode, int],
-        fill: Dict[Union[Type, str], datapoints._FillTypeJIT],
+        fill: Dict[Union[Type, str], _FillTypeJIT],
-    ) -> Union[datapoints._ImageType, datapoints._VideoType]:
+    ) -> ImageOrVideo:
        fill_ = _get_fill(fill, type(image))
        if transform_id == "Identity":
@@ -214,7 +218,7 @@ class AutoAugment(_AutoAugmentBase):
        self,
        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.policy = policy
@@ -394,7 +398,7 @@ class RandAugment(_AutoAugmentBase):
        magnitude: int = 9,
        num_magnitude_bins: int = 31,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_ops = num_ops
@@ -467,7 +471,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
        self,
        num_magnitude_bins: int = 31,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ):
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_magnitude_bins = num_magnitude_bins
@@ -550,7 +554,7 @@ class AugMix(_AutoAugmentBase):
        alpha: float = 1.0,
        all_ops: bool = True,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self._PARAMETER_MAX = 10
...
@@ -261,9 +261,7 @@ class RandomPhotometricDistort(Transform):
        params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None
        return params

-    def _transform(
-        self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
-    ) -> Union[datapoints._ImageType, datapoints._VideoType]:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if params["brightness_factor"] is not None:
            inpt = F.adjust_brightness(inpt, brightness_factor=params["brightness_factor"])
        if params["contrast_factor"] is not None and params["contrast_before"]:
...
@@ -11,6 +11,7 @@ from torchvision.ops.boxes import box_iou
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
from torchvision.transforms.v2.functional._geometry import _check_interpolation
+from torchvision.transforms.v2.functional._utils import _FillType

from ._transform import _RandomApplyTransform
from ._utils import (
@@ -311,9 +312,6 @@ class RandomResizedCrop(Transform):
        )

-ImageOrVideoTypeJIT = Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]

class FiveCrop(Transform):
    """[BETA] Crop the image or video into four corners and the central crop.
@@ -459,7 +457,7 @@ class Pad(Transform):
    def __init__(
        self,
        padding: Union[int, Sequence[int]],
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
        padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
    ) -> None:
        super().__init__()
@@ -514,7 +512,7 @@ class RandomZoomOut(_RandomApplyTransform):
    def __init__(
        self,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
        side_range: Sequence[float] = (1.0, 4.0),
        p: float = 0.5,
    ) -> None:
@@ -592,7 +590,7 @@ class RandomRotation(Transform):
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        center: Optional[List[float]] = None,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
    ) -> None:
        super().__init__()
        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
@@ -674,7 +672,7 @@ class RandomAffine(Transform):
        scale: Optional[Sequence[float]] = None,
        shear: Optional[Union[int, float, Sequence[float]]] = None,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
        center: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
@@ -812,7 +810,7 @@ class RandomCrop(Transform):
        size: Union[int, Sequence[int]],
        padding: Optional[Union[int, Sequence[int]]] = None,
        pad_if_needed: bool = False,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
        padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
    ) -> None:
        super().__init__()
@@ -931,7 +929,7 @@ class RandomPerspective(_RandomApplyTransform):
        distortion_scale: float = 0.5,
        p: float = 0.5,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
    ) -> None:
        super().__init__(p=p)
@@ -1033,7 +1031,7 @@ class ElasticTransform(Transform):
        alpha: Union[float, Sequence[float]] = 50.0,
        sigma: Union[float, Sequence[float]] = 5.0,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
    ) -> None:
        super().__init__()
        self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
...
@@ -169,9 +169,7 @@ class Normalize(Transform):
        if has_any(sample, PIL.Image.Image):
            raise TypeError(f"{type(self).__name__}() does not support PIL images.")

-    def _transform(
-        self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
-    ) -> Any:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
...
from typing import Any, Dict

import torch

-from torchvision import datapoints
from torchvision.transforms.v2 import functional as F, Transform
@@ -25,5 +24,5 @@ class UniformTemporalSubsample(Transform):
        super().__init__()
        self.num_samples = num_samples

-    def _transform(self, inpt: datapoints._VideoType, params: Dict[str, Any]) -> datapoints._VideoType:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.uniform_temporal_subsample(inpt, self.num_samples)
@@ -5,9 +5,8 @@ from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, Union
import torch

-from torchvision import datapoints
-from torchvision.datapoints._datapoint import _FillType, _FillTypeJIT
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401
+from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT

def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]:
@@ -36,7 +35,7 @@ def _check_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -
            raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.")

-def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
+def _convert_fill_arg(fill: _FillType) -> _FillTypeJIT:
    # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
    # So, we can't reassign fill to 0
    # if fill is None:
...
-from typing import Union

import PIL.Image
import torch
@@ -12,14 +10,14 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_inter

@_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True)
def erase(
-    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT],
+    inpt: torch.Tensor,
    i: int,
    j: int,
    h: int,
    w: int,
    v: torch.Tensor,
    inplace: bool = False,
-) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
+) -> torch.Tensor:
    if torch.jit.is_scripting():
        return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
...
-from typing import List, Union
+from typing import List

import PIL.Image
import torch
@@ -16,9 +16,7 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_inter

@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, datapoints.Video)
-def rgb_to_grayscale(
-    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], num_output_channels: int = 1
-) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
+def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
    if torch.jit.is_scripting():
        return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels)
@@ -73,7 +71,7 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Te

@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) -> datapoints._InputTypeJIT:
+def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Tensor:
    if torch.jit.is_scripting():
        return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
@@ -110,7 +108,7 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to

@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def adjust_saturation(inpt: datapoints._InputTypeJIT, saturation_factor: float) -> datapoints._InputTypeJIT:
+def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Tensor:
    if torch.jit.is_scripting():
        return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
@@ -149,7 +147,7 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to

@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def adjust_contrast(inpt: datapoints._InputTypeJIT, contrast_factor: float) -> datapoints._InputTypeJIT:
+def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor:
    if torch.jit.is_scripting():
        return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
@@ -188,7 +186,7 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.

@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def adjust_sharpness(inpt: datapoints._InputTypeJIT, sharpness_factor: float) -> datapoints._InputTypeJIT:
+def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
    if torch.jit.is_scripting():
        return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
@@ -261,7 +259,7 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc

@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def adjust_hue(inpt: datapoints._InputTypeJIT, hue_factor: float) -> datapoints._InputTypeJIT:
+def adjust_hue(inpt: torch.Tensor, hue_factor: float) -> torch.Tensor:
    if torch.jit.is_scripting():
        return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
@@ -373,7 +371,7 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:

@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def adjust_gamma(inpt: datapoints._InputTypeJIT, gamma: float, gain: float = 1) -> datapoints._InputTypeJIT:
+def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
    if torch.jit.is_scripting():
        return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
@@ -413,7 +411,7 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to

@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def posterize(inpt: datapoints._InputTypeJIT, bits: int) -> datapoints._InputTypeJIT:
+def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:
    if torch.jit.is_scripting():
        return posterize_image_tensor(inpt, bits=bits)
@@ -447,7 +445,7 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:

@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._InputTypeJIT:
+def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor:
    if torch.jit.is_scripting():
        return solarize_image_tensor(inpt, threshold=threshold)
@@ -475,7 +473,7 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:

@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
+def autocontrast(inpt: torch.Tensor) -> torch.Tensor:
    if torch.jit.is_scripting():
        return autocontrast_image_tensor(inpt)
@@ -525,7 +523,7 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:

@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
+def equalize(inpt: torch.Tensor) -> torch.Tensor:
    if torch.jit.is_scripting():
        return equalize_image_tensor(inpt)
@@ -615,7 +613,7 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:

@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def invert(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
+def invert(inpt: torch.Tensor) -> torch.Tensor:
    if torch.jit.is_scripting():
        return invert_image_tensor(inpt)
@@ -646,7 +644,7 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:

@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def permute_channels(inpt: datapoints._InputTypeJIT, permutation: List[int]) -> datapoints._InputTypeJIT:
+def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor:
    """Permute the channels of the input according to the given permutation.

    This function supports plain :class:`~torch.Tensor`'s, :class:`PIL.Image.Image`'s, and
...
import warnings
-from typing import Any, List, Union
+from typing import Any, List

import torch

-from torchvision import datapoints
from torchvision.transforms import functional as _F
@@ -16,7 +15,7 @@ def to_tensor(inpt: Any) -> torch.Tensor:
    return _F.to_tensor(inpt)

-def get_image_size(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
+def get_image_size(inpt: torch.Tensor) -> List[int]:
    warnings.warn(
        "The function `get_image_size(...)` is deprecated and will be removed in a future release. "
        "Instead, please use `get_size(...)` which returns `[h, w]` instead of `[w, h]`."
...
@@ -25,7 +25,13 @@ from torchvision.utils import _log_api_usage_once
from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil
-from ._utils import _get_kernel, _register_explicit_noop, _register_five_ten_crop_kernel, _register_kernel_internal
+from ._utils import (
+    _FillTypeJIT,
+    _get_kernel,
+    _register_explicit_noop,
+    _register_five_ten_crop_kernel,
+    _register_kernel_internal,
+)

def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
@@ -39,7 +45,7 @@ def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> Interp
    return interpolation

-def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
+def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor:
    if torch.jit.is_scripting():
        return horizontal_flip_image_tensor(inpt)
@@ -95,7 +101,7 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
    return horizontal_flip_image_tensor(video)

-def vertical_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
+def vertical_flip(inpt: torch.Tensor) -> torch.Tensor:
    if torch.jit.is_scripting():
        return vertical_flip_image_tensor(inpt)
@@ -171,12 +177,12 @@ def _compute_resized_output_size(

def resize(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
    if torch.jit.is_scripting():
        return resize_image_tensor(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
@@ -364,15 +370,15 @@ def resize_video(

def affine(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
    if torch.jit.is_scripting():
        return affine_image_tensor(
            inpt,
@@ -549,9 +555,7 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[in
    return int(size[0]), int(size[1])  # w, h

-def _apply_grid_transform(
-    img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints._FillTypeJIT
-) -> torch.Tensor:
+def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill: _FillTypeJIT) -> torch.Tensor:
    # We are using context knowledge that grid should have float dtype
    fp = img.dtype == grid.dtype
@@ -592,7 +596,7 @@ def _assert_grid_transform_inputs(
    image: torch.Tensor,
    matrix: Optional[List[float]],
    interpolation: str,
-    fill: datapoints._FillTypeJIT,
+    fill: _FillTypeJIT,
    supported_interpolation_modes: List[str],
    coeffs: Optional[List[float]] = None,
) -> None:
@@ -657,7 +661,7 @@ def affine_image_tensor(
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    interpolation = _check_interpolation(interpolation)
@@ -709,7 +713,7 @@ def affine_image_pil(
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> PIL.Image.Image:
    interpolation = _check_interpolation(interpolation)
@@ -868,7 +872,7 @@ def affine_mask(
    translate: List[float],
    scale: float,
    shear: List[float],
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    if mask.ndim < 3:
@@ -901,7 +905,7 @@ def _affine_mask_dispatch(
    translate: List[float],
    scale: float,
    shear: List[float],
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
    **kwargs,
) -> datapoints.Mask:
@@ -925,7 +929,7 @@ def affine_video(
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    return affine_image_tensor(
@@ -941,13 +945,13 @@ def affine_video(

def rotate(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
    if torch.jit.is_scripting():
        return rotate_image_tensor(
            inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center
@@ -967,7 +971,7 @@ def rotate_image_tensor(
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    interpolation = _check_interpolation(interpolation)
@@ -1012,7 +1016,7 @@ def rotate_image_pil(
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
) -> PIL.Image.Image:
    interpolation = _check_interpolation(interpolation)
@@ -1068,7 +1072,7 @@ def rotate_mask(
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
@@ -1097,7 +1101,7 @@ def _rotate_mask_dispatch(
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
    **kwargs,
) -> datapoints.Mask:
    output = rotate_mask(inpt.as_subclass(torch.Tensor), angle=angle, expand=expand, fill=fill, center=center)
@@ -1111,17 +1115,17 @@ def rotate_video(
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    return rotate_image_tensor(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)

def pad(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
    if torch.jit.is_scripting():
        return pad_image_tensor(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
@@ -1336,7 +1340,7 @@ def pad_video(
    return pad_image_tensor(video, padding, fill=fill, padding_mode=padding_mode)

-def crop(inpt: datapoints._InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints._InputTypeJIT:
+def crop(inpt: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    if torch.jit.is_scripting():
        return crop_image_tensor(inpt, top=top, left=left, height=height, width=width)
@@ -1423,13 +1427,13 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int

def perspective(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
    if torch.jit.is_scripting():
        return perspective_image_tensor(
            inpt,
@@ -1507,7 +1511,7 @@ def perspective_image_tensor(
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
@@ -1554,7 +1558,7 @@ def perspective_image_pil(
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BICUBIC,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> PIL.Image.Image:
    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
@@ -1679,7 +1683,7 @@ def perspective_mask(
    mask: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    if mask.ndim < 3:
@@ -1703,7 +1707,7 @@ def _perspective_mask_dispatch(
    inpt: datapoints.Mask,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
    **kwargs,
) -> datapoints.Mask:
@@ -1723,7 +1727,7 @@ def perspective_video(
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    return perspective_image_tensor(
@@ -1732,11 +1736,11 @@ def perspective_video(

def elastic(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
    if torch.jit.is_scripting():
        return elastic_image_tensor(inpt, displacement=displacement, interpolation=interpolation, fill=fill)
@@ -1755,7 +1759,7 @@ def elastic_image_tensor(
    image: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    interpolation = _check_interpolation(interpolation)
@@ -1812,7 +1816,7 @@ def elastic_image_pil(
    image: PIL.Image.Image,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
) -> PIL.Image.Image:
    t_img = pil_to_tensor(image)
    output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill)
@@ -1895,7 +1899,7 @@ def _elastic_bounding_boxes_dispatch(

def elastic_mask(
    mask: torch.Tensor,
    displacement: torch.Tensor,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
@@ -1913,7 +1917,7 @@ def elastic_mask(

@_register_kernel_internal(elastic, datapoints.Mask, datapoint_wrapper=False)
def _elastic_mask_dispatch(
-    inpt: datapoints.Mask, displacement: torch.Tensor, fill: datapoints._FillTypeJIT = None, **kwargs
+    inpt: datapoints.Mask, displacement: torch.Tensor, fill: _FillTypeJIT = None, **kwargs
) -> datapoints.Mask:
    output = elastic_mask(inpt.as_subclass(torch.Tensor), displacement=displacement, fill=fill)
    return datapoints.Mask.wrap_like(inpt, output)
@@ -1924,12 +1928,12 @@ def elastic_video(
    video: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)

-def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datapoints._InputTypeJIT:
+def center_crop(inpt: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    if torch.jit.is_scripting():
        return center_crop_image_tensor(inpt, output_size=output_size)
...@@ -2049,7 +2053,7 @@ def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tens ...@@ -2049,7 +2053,7 @@ def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tens
def resized_crop( def resized_crop(
inpt: datapoints._InputTypeJIT, inpt: torch.Tensor,
top: int, top: int,
left: int, left: int,
height: int, height: int,
...@@ -2057,7 +2061,7 @@ def resized_crop( ...@@ -2057,7 +2061,7 @@ def resized_crop(
size: List[int], size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> datapoints._InputTypeJIT: ) -> torch.Tensor:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return resized_crop_image_tensor( return resized_crop_image_tensor(
inpt, inpt,
...@@ -2201,14 +2205,8 @@ def resized_crop_video( ...@@ -2201,14 +2205,8 @@ def resized_crop_video(
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True) @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True)
def five_crop( def five_crop(
inpt: datapoints._InputTypeJIT, size: List[int] inpt: torch.Tensor, size: List[int]
) -> Tuple[ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
datapoints._InputTypeJIT,
datapoints._InputTypeJIT,
datapoints._InputTypeJIT,
datapoints._InputTypeJIT,
datapoints._InputTypeJIT,
]:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return five_crop_image_tensor(inpt, size=size) return five_crop_image_tensor(inpt, size=size)
...@@ -2280,18 +2278,18 @@ def five_crop_video( ...@@ -2280,18 +2278,18 @@ def five_crop_video(
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True) @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True)
def ten_crop( def ten_crop(
inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], size: List[int], vertical_flip: bool = False inpt: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[ ) -> Tuple[
datapoints._InputTypeJIT, torch.Tensor,
datapoints._InputTypeJIT, torch.Tensor,
datapoints._InputTypeJIT, torch.Tensor,
datapoints._InputTypeJIT, torch.Tensor,
datapoints._InputTypeJIT, torch.Tensor,
datapoints._InputTypeJIT, torch.Tensor,
datapoints._InputTypeJIT, torch.Tensor,
datapoints._InputTypeJIT, torch.Tensor,
datapoints._InputTypeJIT, torch.Tensor,
datapoints._InputTypeJIT, torch.Tensor,
]: ]:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return ten_crop_image_tensor(inpt, size=size, vertical_flip=vertical_flip) return ten_crop_image_tensor(inpt, size=size, vertical_flip=vertical_flip)
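Since `Datapoint` subclasses `torch.Tensor`, the flattened `Tuple[torch.Tensor, ...]` annotations above still cover datapoint inputs; only the spelling changed, not the runtime behavior. A usage sketch under that reading:

```python
# ten_crop still returns a 10-tuple; the annotation is now spelled with
# plain torch.Tensor, which datapoints satisfy by subclassing.
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

img = datapoints.Image(torch.rand(3, 64, 64))
crops = F.ten_crop(img, size=[32, 32])
assert len(crops) == 10
```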
......
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple
import PIL.Image import PIL.Image
import torch import torch
...@@ -12,7 +12,7 @@ from ._utils import _get_kernel, _register_kernel_internal, _register_unsupporte ...@@ -12,7 +12,7 @@ from ._utils import _get_kernel, _register_kernel_internal, _register_unsupporte
@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask) @_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]: def get_dimensions(inpt: torch.Tensor) -> List[int]:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return get_dimensions_image_tensor(inpt) return get_dimensions_image_tensor(inpt)
...@@ -45,7 +45,7 @@ def get_dimensions_video(video: torch.Tensor) -> List[int]: ...@@ -45,7 +45,7 @@ def get_dimensions_video(video: torch.Tensor) -> List[int]:
@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask) @_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> int: def get_num_channels(inpt: torch.Tensor) -> int:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return get_num_channels_image_tensor(inpt) return get_num_channels_image_tensor(inpt)
...@@ -81,7 +81,7 @@ def get_num_channels_video(video: torch.Tensor) -> int: ...@@ -81,7 +81,7 @@ def get_num_channels_video(video: torch.Tensor) -> int:
get_image_num_channels = get_num_channels get_image_num_channels = get_num_channels
def get_size(inpt: datapoints._InputTypeJIT) -> List[int]: def get_size(inpt: torch.Tensor) -> List[int]:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return get_size_image_tensor(inpt) return get_size_image_tensor(inpt)
...@@ -124,7 +124,7 @@ def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int] ...@@ -124,7 +124,7 @@ def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]
@_register_unsupported_type(PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask) @_register_unsupported_type(PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)
def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int: def get_num_frames(inpt: torch.Tensor) -> int:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return get_num_frames_video(inpt) return get_num_frames_video(inpt)
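A short usage sketch for the `_meta.py` dispatchers narrowed above; the return conventions (`[C, H, W]` for `get_dimensions`, `[H, W]` for `get_size`) follow the stable torchvision API:

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

video = datapoints.Video(torch.rand(4, 3, 32, 32))  # (T, C, H, W)
F.get_dimensions(video)  # [3, 32, 32]
F.get_size(video)        # [32, 32]
F.get_num_frames(video)  # 4
```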
...@@ -201,11 +201,11 @@ def _convert_format_bounding_boxes( ...@@ -201,11 +201,11 @@ def _convert_format_bounding_boxes(
def convert_format_bounding_boxes( def convert_format_bounding_boxes(
inpt: datapoints._InputTypeJIT, inpt: torch.Tensor,
old_format: Optional[BoundingBoxFormat] = None, old_format: Optional[BoundingBoxFormat] = None,
new_format: Optional[BoundingBoxFormat] = None, new_format: Optional[BoundingBoxFormat] = None,
inplace: bool = False, inplace: bool = False,
) -> datapoints._InputTypeJIT: ) -> torch.Tensor:
# This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor # This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor
# inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on # inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on
# `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
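The comment above is the reason `old_format` survives in the signature: a plain tensor carries no format metadata, while a `datapoints.BoundingBoxes` does. A sketch of both call styles:

```python
# Plain tensors must spell out old_format; BoundingBoxes carry it themselves.
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

xyxy = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
F.convert_format_bounding_boxes(
    xyxy,
    old_format=datapoints.BoundingBoxFormat.XYXY,
    new_format=datapoints.BoundingBoxFormat.CXCYWH,
)

boxes = datapoints.BoundingBoxes(xyxy, format="XYXY", canvas_size=(20, 20))
F.convert_format_bounding_boxes(boxes, new_format=datapoints.BoundingBoxFormat.CXCYWH)
```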
...@@ -252,10 +252,10 @@ def _clamp_bounding_boxes( ...@@ -252,10 +252,10 @@ def _clamp_bounding_boxes(
def clamp_bounding_boxes( def clamp_bounding_boxes(
inpt: datapoints._InputTypeJIT, inpt: torch.Tensor,
format: Optional[BoundingBoxFormat] = None, format: Optional[BoundingBoxFormat] = None,
canvas_size: Optional[Tuple[int, int]] = None, canvas_size: Optional[Tuple[int, int]] = None,
) -> datapoints._InputTypeJIT: ) -> torch.Tensor:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(clamp_bounding_boxes) _log_api_usage_once(clamp_bounding_boxes)
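`clamp_bounding_boxes` follows the same hybrid contract (this mirroring is an assumption based on the shared signature shape): plain tensors need `format` and `canvas_size` passed explicitly, while `BoundingBoxes` inputs supply their own metadata.

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

clamped = F.clamp_bounding_boxes(
    torch.tensor([[-5.0, -5.0, 30.0, 30.0]]),
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(20, 20),
)  # coordinates clamped into the 20x20 canvas
```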
......
import math import math
from typing import List, Optional, Union from typing import List, Optional
import PIL.Image import PIL.Image
import torch import torch
...@@ -17,7 +17,7 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_inter ...@@ -17,7 +17,7 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_inter
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
@_register_unsupported_type(PIL.Image.Image) @_register_unsupported_type(PIL.Image.Image)
def normalize( def normalize(
inpt: Union[datapoints._TensorImageTypeJIT, datapoints._TensorVideoTypeJIT], inpt: torch.Tensor,
mean: List[float], mean: List[float],
std: List[float], std: List[float],
inplace: bool = False, inplace: bool = False,
...@@ -74,9 +74,7 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], in ...@@ -74,9 +74,7 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], in
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def gaussian_blur( def gaussian_blur(inpt: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> torch.Tensor:
inpt: datapoints._InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> datapoints._InputTypeJIT:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma) return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
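All of these wrappers repeat the same shape: a `torch.jit.is_scripting()` fast path straight into the tensor kernel, then a type-based kernel lookup in eager mode. Below is a self-contained toy of that pattern; the registry names are illustrative, not torchvision internals.

```python
from typing import Callable, Dict, List, Type

import torch

# toy kernel registry; torchvision's _get_kernel/_register_kernel_internal
# machinery is more elaborate, this only mirrors the visible control flow
_KERNELS: Dict[Type, Callable] = {}


def _register(typ: Type) -> Callable:
    def deco(fn: Callable) -> Callable:
        _KERNELS[typ] = fn
        return fn

    return deco


@_register(torch.Tensor)
def _blur_tensor(inpt: torch.Tensor, kernel_size: List[int]) -> torch.Tensor:
    return inpt  # stand-in for the real gaussian_blur_image_tensor kernel


def gaussian_blur_like(inpt: torch.Tensor, kernel_size: List[int]) -> torch.Tensor:
    if torch.jit.is_scripting():
        # TorchScript never reaches the registry: it only ever sees plain tensors
        return _blur_tensor(inpt, kernel_size)
    return _KERNELS[type(inpt)](inpt, kernel_size)


out = gaussian_blur_like(torch.rand(3, 8, 8), kernel_size=[5, 5])
```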
...@@ -185,9 +183,7 @@ def gaussian_blur_video( ...@@ -185,9 +183,7 @@ def gaussian_blur_video(
@_register_unsupported_type(PIL.Image.Image) @_register_unsupported_type(PIL.Image.Image)
def to_dtype( def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False
) -> datapoints._InputTypeJIT:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return to_dtype_image_tensor(inpt, dtype=dtype, scale=scale) return to_dtype_image_tensor(inpt, dtype=dtype, scale=scale)
...@@ -278,8 +274,6 @@ def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: ...@@ -278,8 +274,6 @@ def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale:
@_register_kernel_internal(to_dtype, datapoints.BoundingBoxes, datapoint_wrapper=False) @_register_kernel_internal(to_dtype, datapoints.BoundingBoxes, datapoint_wrapper=False)
@_register_kernel_internal(to_dtype, datapoints.Mask, datapoint_wrapper=False) @_register_kernel_internal(to_dtype, datapoints.Mask, datapoint_wrapper=False)
def _to_dtype_tensor_dispatch( def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: bool = False) -> torch.Tensor:
inpt: datapoints._InputTypeJIT, dtype: torch.dtype, scale: bool = False
) -> datapoints._InputTypeJIT:
# We don't need to unwrap and rewrap here, since Datapoint.to() preserves the type # We don't need to unwrap and rewrap here, since Datapoint.to() preserves the type
return inpt.to(dtype) return inpt.to(dtype)
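A quick check of the property the comment above relies on: `.to()` on a `Datapoint` subclass returns the same subclass, so the bounding-box and mask dispatch can skip the usual unwrap/rewrap dance.

```python
import torch
from torchvision import datapoints

mask = datapoints.Mask(torch.zeros(1, 4, 4, dtype=torch.uint8))
converted = mask.to(torch.float32)
assert isinstance(converted, datapoints.Mask)  # subclass preserved by .to()
assert converted.dtype == torch.float32
```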
...@@ -11,7 +11,7 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_inter ...@@ -11,7 +11,7 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_inter
@_register_explicit_noop( @_register_explicit_noop(
PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True
) )
def uniform_temporal_subsample(inpt: datapoints._VideoTypeJIT, num_samples: int) -> datapoints._VideoTypeJIT: def uniform_temporal_subsample(inpt: torch.Tensor, num_samples: int) -> torch.Tensor:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return uniform_temporal_subsample_video(inpt, num_samples=num_samples) return uniform_temporal_subsample_video(inpt, num_samples=num_samples)
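A usage sketch for the narrowed `uniform_temporal_subsample`: it picks `num_samples` frames evenly along the temporal axis, with `(T, C, H, W)` being the `Video` layout convention.

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

video = datapoints.Video(torch.rand(8, 3, 16, 16))  # (T, C, H, W)
sub = F.uniform_temporal_subsample(video, num_samples=4)
assert sub.shape[0] == 4 and isinstance(sub, datapoints.Video)
```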
......