Unverified Commit d4d20f01 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

make type alias private (#7266)

parent e405f3c3
from ._bounding_box import BoundingBox, BoundingBoxFormat from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT
from ._image import Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image
from ._mask import Mask from ._mask import Mask
from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video
from ._dataset_wrapper import wrap_dataset_for_transforms_v2 # type: ignore[attr-defined] # usort: skip from ._dataset_wrapper import wrap_dataset_for_transforms_v2 # type: ignore[attr-defined] # usort: skip
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
from torchvision._utils import StrEnum from torchvision._utils import StrEnum
from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms
from ._datapoint import Datapoint, FillTypeJIT from ._datapoint import _FillTypeJIT, Datapoint
class BoundingBoxFormat(StrEnum): class BoundingBoxFormat(StrEnum):
...@@ -136,7 +136,7 @@ class BoundingBox(Datapoint): ...@@ -136,7 +136,7 @@ class BoundingBox(Datapoint):
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: FillTypeJIT = None, fill: _FillTypeJIT = None,
) -> BoundingBox: ) -> BoundingBox:
output, spatial_size = self._F.rotate_bounding_box( output, spatial_size = self._F.rotate_bounding_box(
self.as_subclass(torch.Tensor), self.as_subclass(torch.Tensor),
...@@ -155,7 +155,7 @@ class BoundingBox(Datapoint): ...@@ -155,7 +155,7 @@ class BoundingBox(Datapoint):
scale: float, scale: float,
shear: List[float], shear: List[float],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None, fill: _FillTypeJIT = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> BoundingBox: ) -> BoundingBox:
output = self._F.affine_bounding_box( output = self._F.affine_bounding_box(
...@@ -175,7 +175,7 @@ class BoundingBox(Datapoint): ...@@ -175,7 +175,7 @@ class BoundingBox(Datapoint):
startpoints: Optional[List[List[int]]], startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None, fill: _FillTypeJIT = None,
coefficients: Optional[List[float]] = None, coefficients: Optional[List[float]] = None,
) -> BoundingBox: ) -> BoundingBox:
output = self._F.perspective_bounding_box( output = self._F.perspective_bounding_box(
...@@ -192,7 +192,7 @@ class BoundingBox(Datapoint): ...@@ -192,7 +192,7 @@ class BoundingBox(Datapoint):
self, self,
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None, fill: _FillTypeJIT = None,
) -> BoundingBox: ) -> BoundingBox:
output = self._F.elastic_bounding_box( output = self._F.elastic_bounding_box(
self.as_subclass(torch.Tensor), self.format, self.spatial_size, displacement=displacement self.as_subclass(torch.Tensor), self.format, self.spatial_size, displacement=displacement
......
...@@ -11,8 +11,8 @@ from torchvision.transforms import InterpolationMode ...@@ -11,8 +11,8 @@ from torchvision.transforms import InterpolationMode
D = TypeVar("D", bound="Datapoint") D = TypeVar("D", bound="Datapoint")
FillType = Union[int, float, Sequence[int], Sequence[float], None] _FillType = Union[int, float, Sequence[int], Sequence[float], None]
FillTypeJIT = Optional[List[float]] _FillTypeJIT = Optional[List[float]]
class Datapoint(torch.Tensor): class Datapoint(torch.Tensor):
...@@ -181,7 +181,7 @@ class Datapoint(torch.Tensor): ...@@ -181,7 +181,7 @@ class Datapoint(torch.Tensor):
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: FillTypeJIT = None, fill: _FillTypeJIT = None,
) -> Datapoint: ) -> Datapoint:
return self return self
...@@ -192,7 +192,7 @@ class Datapoint(torch.Tensor): ...@@ -192,7 +192,7 @@ class Datapoint(torch.Tensor):
scale: float, scale: float,
shear: List[float], shear: List[float],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None, fill: _FillTypeJIT = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> Datapoint: ) -> Datapoint:
return self return self
...@@ -202,7 +202,7 @@ class Datapoint(torch.Tensor): ...@@ -202,7 +202,7 @@ class Datapoint(torch.Tensor):
startpoints: Optional[List[List[int]]], startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None, fill: _FillTypeJIT = None,
coefficients: Optional[List[float]] = None, coefficients: Optional[List[float]] = None,
) -> Datapoint: ) -> Datapoint:
return self return self
...@@ -211,7 +211,7 @@ class Datapoint(torch.Tensor): ...@@ -211,7 +211,7 @@ class Datapoint(torch.Tensor):
self, self,
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None, fill: _FillTypeJIT = None,
) -> Datapoint: ) -> Datapoint:
return self return self
...@@ -255,5 +255,5 @@ class Datapoint(torch.Tensor): ...@@ -255,5 +255,5 @@ class Datapoint(torch.Tensor):
return self return self
InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint] _InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint]
InputTypeJIT = torch.Tensor _InputTypeJIT = torch.Tensor
...@@ -6,7 +6,7 @@ import PIL.Image ...@@ -6,7 +6,7 @@ import PIL.Image
import torch import torch
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ._datapoint import Datapoint, FillTypeJIT from ._datapoint import _FillTypeJIT, Datapoint
class Image(Datapoint): class Image(Datapoint):
...@@ -116,7 +116,7 @@ class Image(Datapoint): ...@@ -116,7 +116,7 @@ class Image(Datapoint):
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: FillTypeJIT = None, fill: _FillTypeJIT = None,
) -> Image: ) -> Image:
output = self._F.rotate_image_tensor( output = self._F.rotate_image_tensor(
self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center
...@@ -130,7 +130,7 @@ class Image(Datapoint): ...@@ -130,7 +130,7 @@ class Image(Datapoint):
scale: float, scale: float,
shear: List[float], shear: List[float],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None, fill: _FillTypeJIT = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> Image: ) -> Image:
output = self._F.affine_image_tensor( output = self._F.affine_image_tensor(
...@@ -150,7 +150,7 @@ class Image(Datapoint): ...@@ -150,7 +150,7 @@ class Image(Datapoint):
startpoints: Optional[List[List[int]]], startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None, fill: _FillTypeJIT = None,
coefficients: Optional[List[float]] = None, coefficients: Optional[List[float]] = None,
) -> Image: ) -> Image:
output = self._F.perspective_image_tensor( output = self._F.perspective_image_tensor(
...@@ -167,7 +167,7 @@ class Image(Datapoint): ...@@ -167,7 +167,7 @@ class Image(Datapoint):
self, self,
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None, fill: _FillTypeJIT = None,
) -> Image: ) -> Image:
output = self._F.elastic_image_tensor( output = self._F.elastic_image_tensor(
self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill
...@@ -241,7 +241,7 @@ class Image(Datapoint): ...@@ -241,7 +241,7 @@ class Image(Datapoint):
return Image.wrap_like(self, output) return Image.wrap_like(self, output)
ImageType = Union[torch.Tensor, PIL.Image.Image, Image] _ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
ImageTypeJIT = torch.Tensor _ImageTypeJIT = torch.Tensor
TensorImageType = Union[torch.Tensor, Image] _TensorImageType = Union[torch.Tensor, Image]
TensorImageTypeJIT = torch.Tensor _TensorImageTypeJIT = torch.Tensor
...@@ -6,7 +6,7 @@ import PIL.Image ...@@ -6,7 +6,7 @@ import PIL.Image
import torch import torch
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
from ._datapoint import Datapoint, FillTypeJIT from ._datapoint import _FillTypeJIT, Datapoint
class Mask(Datapoint): class Mask(Datapoint):
...@@ -96,7 +96,7 @@ class Mask(Datapoint): ...@@ -96,7 +96,7 @@ class Mask(Datapoint):
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: FillTypeJIT = None, fill: _FillTypeJIT = None,
) -> Mask: ) -> Mask:
output = self._F.rotate_mask(self.as_subclass(torch.Tensor), angle, expand=expand, center=center, fill=fill) output = self._F.rotate_mask(self.as_subclass(torch.Tensor), angle, expand=expand, center=center, fill=fill)
return Mask.wrap_like(self, output) return Mask.wrap_like(self, output)
...@@ -108,7 +108,7 @@ class Mask(Datapoint): ...@@ -108,7 +108,7 @@ class Mask(Datapoint):
scale: float, scale: float,
shear: List[float], shear: List[float],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None, fill: _FillTypeJIT = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> Mask: ) -> Mask:
output = self._F.affine_mask( output = self._F.affine_mask(
...@@ -127,7 +127,7 @@ class Mask(Datapoint): ...@@ -127,7 +127,7 @@ class Mask(Datapoint):
startpoints: Optional[List[List[int]]], startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None, fill: _FillTypeJIT = None,
coefficients: Optional[List[float]] = None, coefficients: Optional[List[float]] = None,
) -> Mask: ) -> Mask:
output = self._F.perspective_mask( output = self._F.perspective_mask(
...@@ -139,7 +139,7 @@ class Mask(Datapoint): ...@@ -139,7 +139,7 @@ class Mask(Datapoint):
self, self,
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None, fill: _FillTypeJIT = None,
) -> Mask: ) -> Mask:
output = self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill) output = self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill)
return Mask.wrap_like(self, output) return Mask.wrap_like(self, output)
...@@ -5,7 +5,7 @@ from typing import Any, List, Optional, Tuple, Union ...@@ -5,7 +5,7 @@ from typing import Any, List, Optional, Tuple, Union
import torch import torch
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ._datapoint import Datapoint, FillTypeJIT from ._datapoint import _FillTypeJIT, Datapoint
class Video(Datapoint): class Video(Datapoint):
...@@ -115,7 +115,7 @@ class Video(Datapoint): ...@@ -115,7 +115,7 @@ class Video(Datapoint):
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: FillTypeJIT = None, fill: _FillTypeJIT = None,
) -> Video: ) -> Video:
output = self._F.rotate_video( output = self._F.rotate_video(
self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center
...@@ -129,7 +129,7 @@ class Video(Datapoint): ...@@ -129,7 +129,7 @@ class Video(Datapoint):
scale: float, scale: float,
shear: List[float], shear: List[float],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None, fill: _FillTypeJIT = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> Video: ) -> Video:
output = self._F.affine_video( output = self._F.affine_video(
...@@ -149,7 +149,7 @@ class Video(Datapoint): ...@@ -149,7 +149,7 @@ class Video(Datapoint):
startpoints: Optional[List[List[int]]], startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None, fill: _FillTypeJIT = None,
coefficients: Optional[List[float]] = None, coefficients: Optional[List[float]] = None,
) -> Video: ) -> Video:
output = self._F.perspective_video( output = self._F.perspective_video(
...@@ -166,7 +166,7 @@ class Video(Datapoint): ...@@ -166,7 +166,7 @@ class Video(Datapoint):
self, self,
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None, fill: _FillTypeJIT = None,
) -> Video: ) -> Video:
output = self._F.elastic_video( output = self._F.elastic_video(
self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill
...@@ -232,7 +232,7 @@ class Video(Datapoint): ...@@ -232,7 +232,7 @@ class Video(Datapoint):
return Video.wrap_like(self, output) return Video.wrap_like(self, output)
VideoType = Union[torch.Tensor, Video] _VideoType = Union[torch.Tensor, Video]
VideoTypeJIT = torch.Tensor _VideoTypeJIT = torch.Tensor
TensorVideoType = Union[torch.Tensor, Video] _TensorVideoType = Union[torch.Tensor, Video]
TensorVideoTypeJIT = torch.Tensor _TensorVideoTypeJIT = torch.Tensor
...@@ -119,15 +119,15 @@ class SimpleCopyPaste(Transform): ...@@ -119,15 +119,15 @@ class SimpleCopyPaste(Transform):
def _copy_paste( def _copy_paste(
self, self,
image: datapoints.TensorImageType, image: datapoints._TensorImageType,
target: Dict[str, Any], target: Dict[str, Any],
paste_image: datapoints.TensorImageType, paste_image: datapoints._TensorImageType,
paste_target: Dict[str, Any], paste_target: Dict[str, Any],
random_selection: torch.Tensor, random_selection: torch.Tensor,
blending: bool, blending: bool,
resize_interpolation: F.InterpolationMode, resize_interpolation: F.InterpolationMode,
antialias: Optional[bool], antialias: Optional[bool],
) -> Tuple[datapoints.TensorImageType, Dict[str, Any]]: ) -> Tuple[datapoints._TensorImageType, Dict[str, Any]]:
paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection]) paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection])
paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection]) paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection])
...@@ -199,7 +199,7 @@ class SimpleCopyPaste(Transform): ...@@ -199,7 +199,7 @@ class SimpleCopyPaste(Transform):
def _extract_image_targets( def _extract_image_targets(
self, flat_sample: List[Any] self, flat_sample: List[Any]
) -> Tuple[List[datapoints.TensorImageType], List[Dict[str, Any]]]: ) -> Tuple[List[datapoints._TensorImageType], List[Dict[str, Any]]]:
# fetch all images, bboxes, masks and labels from unstructured input # fetch all images, bboxes, masks and labels from unstructured input
# with List[image], List[BoundingBox], List[Mask], List[Label] # with List[image], List[BoundingBox], List[Mask], List[Label]
images, bboxes, masks, labels = [], [], [], [] images, bboxes, masks, labels = [], [], [], []
...@@ -230,7 +230,7 @@ class SimpleCopyPaste(Transform): ...@@ -230,7 +230,7 @@ class SimpleCopyPaste(Transform):
def _insert_outputs( def _insert_outputs(
self, self,
flat_sample: List[Any], flat_sample: List[Any],
output_images: List[datapoints.TensorImageType], output_images: List[datapoints._TensorImageType],
output_targets: List[Dict[str, Any]], output_targets: List[Dict[str, Any]],
) -> None: ) -> None:
c0, c1, c2, c3 = 0, 0, 0, 0 c0, c1, c2, c3 = 0, 0, 0, 0
......
...@@ -14,7 +14,7 @@ class FixedSizeCrop(Transform): ...@@ -14,7 +14,7 @@ class FixedSizeCrop(Transform):
def __init__( def __init__(
self, self,
size: Union[int, Sequence[int]], size: Union[int, Sequence[int]],
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
padding_mode: str = "constant", padding_mode: str = "constant",
) -> None: ) -> None:
super().__init__() super().__init__()
......
...@@ -26,7 +26,7 @@ class PermuteDimensions(Transform): ...@@ -26,7 +26,7 @@ class PermuteDimensions(Transform):
self.dims = dims self.dims = dims
def _transform( def _transform(
self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any] self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
) -> torch.Tensor: ) -> torch.Tensor:
dims = self.dims[type(inpt)] dims = self.dims[type(inpt)]
if dims is None: if dims is None:
...@@ -50,7 +50,7 @@ class TransposeDimensions(Transform): ...@@ -50,7 +50,7 @@ class TransposeDimensions(Transform):
self.dims = dims self.dims = dims
def _transform( def _transform(
self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any] self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
) -> torch.Tensor: ) -> torch.Tensor:
dims = self.dims[type(inpt)] dims = self.dims[type(inpt)]
if dims is None: if dims is None:
......
...@@ -97,8 +97,8 @@ class RandomErasing(_RandomApplyTransform): ...@@ -97,8 +97,8 @@ class RandomErasing(_RandomApplyTransform):
return dict(i=i, j=j, h=h, w=w, v=v) return dict(i=i, j=j, h=h, w=w, v=v)
def _transform( def _transform(
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
) -> Union[datapoints.ImageType, datapoints.VideoType]: ) -> Union[datapoints._ImageType, datapoints._VideoType]:
if params["v"] is not None: if params["v"] is not None:
inpt = F.erase(inpt, **params, inplace=self.inplace) inpt = F.erase(inpt, **params, inplace=self.inplace)
......
...@@ -20,7 +20,7 @@ class _AutoAugmentBase(Transform): ...@@ -20,7 +20,7 @@ class _AutoAugmentBase(Transform):
self, self,
*, *,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None, fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.interpolation = _check_interpolation(interpolation) self.interpolation = _check_interpolation(interpolation)
...@@ -35,7 +35,7 @@ class _AutoAugmentBase(Transform): ...@@ -35,7 +35,7 @@ class _AutoAugmentBase(Transform):
self, self,
inputs: Any, inputs: Any,
unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBox, datapoints.Mask), unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBox, datapoints.Mask),
) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints.ImageType, datapoints.VideoType]]: ) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints._ImageType, datapoints._VideoType]]:
flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0]) flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
needs_transform_list = self._needs_transform_list(flat_inputs) needs_transform_list = self._needs_transform_list(flat_inputs)
...@@ -68,7 +68,7 @@ class _AutoAugmentBase(Transform): ...@@ -68,7 +68,7 @@ class _AutoAugmentBase(Transform):
def _unflatten_and_insert_image_or_video( def _unflatten_and_insert_image_or_video(
self, self,
flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int], flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
image_or_video: Union[datapoints.ImageType, datapoints.VideoType], image_or_video: Union[datapoints._ImageType, datapoints._VideoType],
) -> Any: ) -> Any:
flat_inputs, spec, idx = flat_inputs_with_spec flat_inputs, spec, idx = flat_inputs_with_spec
flat_inputs[idx] = image_or_video flat_inputs[idx] = image_or_video
...@@ -76,12 +76,12 @@ class _AutoAugmentBase(Transform): ...@@ -76,12 +76,12 @@ class _AutoAugmentBase(Transform):
def _apply_image_or_video_transform( def _apply_image_or_video_transform(
self, self,
image: Union[datapoints.ImageType, datapoints.VideoType], image: Union[datapoints._ImageType, datapoints._VideoType],
transform_id: str, transform_id: str,
magnitude: float, magnitude: float,
interpolation: Union[InterpolationMode, int], interpolation: Union[InterpolationMode, int],
fill: Dict[Type, datapoints.FillTypeJIT], fill: Dict[Type, datapoints._FillTypeJIT],
) -> Union[datapoints.ImageType, datapoints.VideoType]: ) -> Union[datapoints._ImageType, datapoints._VideoType]:
fill_ = fill[type(image)] fill_ = fill[type(image)]
if transform_id == "Identity": if transform_id == "Identity":
...@@ -194,7 +194,7 @@ class AutoAugment(_AutoAugmentBase): ...@@ -194,7 +194,7 @@ class AutoAugment(_AutoAugmentBase):
self, self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None, fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
) -> None: ) -> None:
super().__init__(interpolation=interpolation, fill=fill) super().__init__(interpolation=interpolation, fill=fill)
self.policy = policy self.policy = policy
...@@ -351,7 +351,7 @@ class RandAugment(_AutoAugmentBase): ...@@ -351,7 +351,7 @@ class RandAugment(_AutoAugmentBase):
magnitude: int = 9, magnitude: int = 9,
num_magnitude_bins: int = 31, num_magnitude_bins: int = 31,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None, fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
) -> None: ) -> None:
super().__init__(interpolation=interpolation, fill=fill) super().__init__(interpolation=interpolation, fill=fill)
self.num_ops = num_ops self.num_ops = num_ops
...@@ -404,7 +404,7 @@ class TrivialAugmentWide(_AutoAugmentBase): ...@@ -404,7 +404,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
self, self,
num_magnitude_bins: int = 31, num_magnitude_bins: int = 31,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None, fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
): ):
super().__init__(interpolation=interpolation, fill=fill) super().__init__(interpolation=interpolation, fill=fill)
self.num_magnitude_bins = num_magnitude_bins self.num_magnitude_bins = num_magnitude_bins
...@@ -462,7 +462,7 @@ class AugMix(_AutoAugmentBase): ...@@ -462,7 +462,7 @@ class AugMix(_AutoAugmentBase):
alpha: float = 1.0, alpha: float = 1.0,
all_ops: bool = True, all_ops: bool = True,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None, fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
) -> None: ) -> None:
super().__init__(interpolation=interpolation, fill=fill) super().__init__(interpolation=interpolation, fill=fill)
self._PARAMETER_MAX = 10 self._PARAMETER_MAX = 10
......
...@@ -163,8 +163,8 @@ class RandomPhotometricDistort(Transform): ...@@ -163,8 +163,8 @@ class RandomPhotometricDistort(Transform):
) )
def _permute_channels( def _permute_channels(
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], permutation: torch.Tensor self, inpt: Union[datapoints._ImageType, datapoints._VideoType], permutation: torch.Tensor
) -> Union[datapoints.ImageType, datapoints.VideoType]: ) -> Union[datapoints._ImageType, datapoints._VideoType]:
orig_inpt = inpt orig_inpt = inpt
if isinstance(orig_inpt, PIL.Image.Image): if isinstance(orig_inpt, PIL.Image.Image):
...@@ -179,8 +179,8 @@ class RandomPhotometricDistort(Transform): ...@@ -179,8 +179,8 @@ class RandomPhotometricDistort(Transform):
return output return output
def _transform( def _transform(
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
) -> Union[datapoints.ImageType, datapoints.VideoType]: ) -> Union[datapoints._ImageType, datapoints._VideoType]:
if params["brightness"]: if params["brightness"]:
inpt = F.adjust_brightness( inpt = F.adjust_brightness(
inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1]) inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1])
......
...@@ -160,7 +160,7 @@ class RandomResizedCrop(Transform): ...@@ -160,7 +160,7 @@ class RandomResizedCrop(Transform):
) )
ImageOrVideoTypeJIT = Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT] ImageOrVideoTypeJIT = Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]
class FiveCrop(Transform): class FiveCrop(Transform):
...@@ -232,7 +232,7 @@ class TenCrop(Transform): ...@@ -232,7 +232,7 @@ class TenCrop(Transform):
raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()") raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
def _transform( def _transform(
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
) -> Tuple[ ) -> Tuple[
ImageOrVideoTypeJIT, ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT, ImageOrVideoTypeJIT,
...@@ -264,7 +264,7 @@ class Pad(Transform): ...@@ -264,7 +264,7 @@ class Pad(Transform):
def __init__( def __init__(
self, self,
padding: Union[int, Sequence[int]], padding: Union[int, Sequence[int]],
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -287,7 +287,7 @@ class Pad(Transform): ...@@ -287,7 +287,7 @@ class Pad(Transform):
class RandomZoomOut(_RandomApplyTransform): class RandomZoomOut(_RandomApplyTransform):
def __init__( def __init__(
self, self,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
side_range: Sequence[float] = (1.0, 4.0), side_range: Sequence[float] = (1.0, 4.0),
p: float = 0.5, p: float = 0.5,
) -> None: ) -> None:
...@@ -330,7 +330,7 @@ class RandomRotation(Transform): ...@@ -330,7 +330,7 @@ class RandomRotation(Transform):
degrees: Union[numbers.Number, Sequence], degrees: Union[numbers.Number, Sequence],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -371,7 +371,7 @@ class RandomAffine(Transform): ...@@ -371,7 +371,7 @@ class RandomAffine(Transform):
scale: Optional[Sequence[float]] = None, scale: Optional[Sequence[float]] = None,
shear: Optional[Union[int, float, Sequence[float]]] = None, shear: Optional[Union[int, float, Sequence[float]]] = None,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -464,7 +464,7 @@ class RandomCrop(Transform): ...@@ -464,7 +464,7 @@ class RandomCrop(Transform):
size: Union[int, Sequence[int]], size: Union[int, Sequence[int]],
padding: Optional[Union[int, Sequence[int]]] = None, padding: Optional[Union[int, Sequence[int]]] = None,
pad_if_needed: bool = False, pad_if_needed: bool = False,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -556,7 +556,7 @@ class RandomPerspective(_RandomApplyTransform): ...@@ -556,7 +556,7 @@ class RandomPerspective(_RandomApplyTransform):
def __init__( def __init__(
self, self,
distortion_scale: float = 0.5, distortion_scale: float = 0.5,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
p: float = 0.5, p: float = 0.5,
) -> None: ) -> None:
...@@ -618,7 +618,7 @@ class ElasticTransform(Transform): ...@@ -618,7 +618,7 @@ class ElasticTransform(Transform):
self, self,
alpha: Union[float, Sequence[float]] = 50.0, alpha: Union[float, Sequence[float]] = 50.0,
sigma: Union[float, Sequence[float]] = 5.0, sigma: Union[float, Sequence[float]] = 5.0,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> None: ) -> None:
super().__init__() super().__init__()
......
...@@ -31,8 +31,8 @@ class ConvertDtype(Transform): ...@@ -31,8 +31,8 @@ class ConvertDtype(Transform):
self.dtype = dtype self.dtype = dtype
def _transform( def _transform(
self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any] self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
) -> Union[datapoints.TensorImageType, datapoints.TensorVideoType]: ) -> Union[datapoints._TensorImageType, datapoints._TensorVideoType]:
return F.convert_dtype(inpt, self.dtype) return F.convert_dtype(inpt, self.dtype)
......
...@@ -119,7 +119,7 @@ class Normalize(Transform): ...@@ -119,7 +119,7 @@ class Normalize(Transform):
raise TypeError(f"{type(self).__name__}() does not support PIL images.") raise TypeError(f"{type(self).__name__}() does not support PIL images.")
def _transform( def _transform(
self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any] self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
) -> Any: ) -> Any:
return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace) return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
......
...@@ -13,5 +13,5 @@ class UniformTemporalSubsample(Transform): ...@@ -13,5 +13,5 @@ class UniformTemporalSubsample(Transform):
super().__init__() super().__init__()
self.num_samples = num_samples self.num_samples = num_samples
def _transform(self, inpt: datapoints.VideoType, params: Dict[str, Any]) -> datapoints.VideoType: def _transform(self, inpt: datapoints._VideoType, params: Dict[str, Any]) -> datapoints._VideoType:
return F.uniform_temporal_subsample(inpt, self.num_samples) return F.uniform_temporal_subsample(inpt, self.num_samples)
...@@ -4,7 +4,7 @@ from collections import defaultdict ...@@ -4,7 +4,7 @@ from collections import defaultdict
from typing import Any, Dict, Literal, Sequence, Type, TypeVar, Union from typing import Any, Dict, Literal, Sequence, Type, TypeVar, Union
from torchvision import datapoints from torchvision import datapoints
from torchvision.datapoints._datapoint import FillType, FillTypeJIT from torchvision.datapoints._datapoint import _FillType, _FillTypeJIT
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
...@@ -26,7 +26,7 @@ def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: ...@@ -26,7 +26,7 @@ def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size:
return arg return arg
def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None: def _check_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> None:
if isinstance(fill, dict): if isinstance(fill, dict):
for key, value in fill.items(): for key, value in fill.items():
# Check key for type # Check key for type
...@@ -52,7 +52,7 @@ def _get_defaultdict(default: T) -> Dict[Any, T]: ...@@ -52,7 +52,7 @@ def _get_defaultdict(default: T) -> Dict[Any, T]:
return defaultdict(functools.partial(_default_arg, default)) return defaultdict(functools.partial(_default_arg, default))
def _convert_fill_arg(fill: datapoints.FillType) -> datapoints.FillTypeJIT: def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517 # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# So, we can't reassign fill to 0 # So, we can't reassign fill to 0
# if fill is None: # if fill is None:
...@@ -65,7 +65,7 @@ def _convert_fill_arg(fill: datapoints.FillType) -> datapoints.FillTypeJIT: ...@@ -65,7 +65,7 @@ def _convert_fill_arg(fill: datapoints.FillType) -> datapoints.FillTypeJIT:
return fill # type: ignore[return-value] return fill # type: ignore[return-value]
def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillTypeJIT]: def _setup_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> Dict[Type, _FillTypeJIT]:
_check_fill_arg(fill) _check_fill_arg(fill)
if isinstance(fill, dict): if isinstance(fill, dict):
......
...@@ -36,14 +36,14 @@ def erase_video( ...@@ -36,14 +36,14 @@ def erase_video(
def erase( def erase(
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT],
i: int, i: int,
j: int, j: int,
h: int, h: int,
w: int, w: int,
v: torch.Tensor, v: torch.Tensor,
inplace: bool = False, inplace: bool = False,
) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]: ) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(erase) _log_api_usage_once(erase)
......
...@@ -37,8 +37,8 @@ rgb_to_grayscale_image_pil = _FP.to_grayscale ...@@ -37,8 +37,8 @@ rgb_to_grayscale_image_pil = _FP.to_grayscale
def rgb_to_grayscale( def rgb_to_grayscale(
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1 inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], num_output_channels: int = 1
) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]: ) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(rgb_to_grayscale) _log_api_usage_once(rgb_to_grayscale)
if num_output_channels not in (1, 3): if num_output_channels not in (1, 3):
...@@ -85,7 +85,7 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to ...@@ -85,7 +85,7 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to
return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor) return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor)
def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) -> datapoints.InputTypeJIT: def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(adjust_brightness) _log_api_usage_once(adjust_brightness)
...@@ -127,7 +127,7 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to ...@@ -127,7 +127,7 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor) return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor)
def adjust_saturation(inpt: datapoints.InputTypeJIT, saturation_factor: float) -> datapoints.InputTypeJIT: def adjust_saturation(inpt: datapoints._InputTypeJIT, saturation_factor: float) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(adjust_saturation) _log_api_usage_once(adjust_saturation)
...@@ -171,7 +171,7 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch. ...@@ -171,7 +171,7 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.
return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor) return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor)
def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> datapoints.InputTypeJIT: def adjust_contrast(inpt: datapoints._InputTypeJIT, contrast_factor: float) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(adjust_contrast) _log_api_usage_once(adjust_contrast)
...@@ -247,7 +247,7 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc ...@@ -247,7 +247,7 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc
return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor) return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor)
def adjust_sharpness(inpt: datapoints.InputTypeJIT, sharpness_factor: float) -> datapoints.InputTypeJIT: def adjust_sharpness(inpt: datapoints._InputTypeJIT, sharpness_factor: float) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(adjust_sharpness) _log_api_usage_once(adjust_sharpness)
...@@ -364,7 +364,7 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor: ...@@ -364,7 +364,7 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
return adjust_hue_image_tensor(video, hue_factor=hue_factor) return adjust_hue_image_tensor(video, hue_factor=hue_factor)
def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.InputTypeJIT: def adjust_hue(inpt: datapoints._InputTypeJIT, hue_factor: float) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(adjust_hue) _log_api_usage_once(adjust_hue)
...@@ -407,7 +407,7 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to ...@@ -407,7 +407,7 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to
return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain) return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain)
def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) -> datapoints.InputTypeJIT: def adjust_gamma(inpt: datapoints._InputTypeJIT, gamma: float, gain: float = 1) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(adjust_gamma) _log_api_usage_once(adjust_gamma)
...@@ -444,7 +444,7 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor: ...@@ -444,7 +444,7 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
return posterize_image_tensor(video, bits=bits) return posterize_image_tensor(video, bits=bits)
def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJIT: def posterize(inpt: datapoints._InputTypeJIT, bits: int) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(posterize) _log_api_usage_once(posterize)
...@@ -475,7 +475,7 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor: ...@@ -475,7 +475,7 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
return solarize_image_tensor(video, threshold=threshold) return solarize_image_tensor(video, threshold=threshold)
def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.InputTypeJIT: def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(solarize) _log_api_usage_once(solarize)
...@@ -528,7 +528,7 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor: ...@@ -528,7 +528,7 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
return autocontrast_image_tensor(video) return autocontrast_image_tensor(video)
def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(autocontrast) _log_api_usage_once(autocontrast)
...@@ -621,7 +621,7 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor: ...@@ -621,7 +621,7 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
return equalize_image_tensor(video) return equalize_image_tensor(video)
def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(equalize) _log_api_usage_once(equalize)
...@@ -655,7 +655,7 @@ def invert_video(video: torch.Tensor) -> torch.Tensor: ...@@ -655,7 +655,7 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
return invert_image_tensor(video) return invert_image_tensor(video)
def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: def invert(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(invert) _log_api_usage_once(invert)
......
...@@ -31,7 +31,7 @@ def to_tensor(inpt: Any) -> torch.Tensor: ...@@ -31,7 +31,7 @@ def to_tensor(inpt: Any) -> torch.Tensor:
return _F.to_tensor(inpt) return _F.to_tensor(inpt)
def get_image_size(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]) -> List[int]: def get_image_size(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
warnings.warn( warnings.warn(
"The function `get_image_size(...)` is deprecated and will be removed in a future release. " "The function `get_image_size(...)` is deprecated and will be removed in a future release. "
"Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`." "Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment