"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a7ca03aa85f94574f06576d2155b3ec061fe8d63"
Unverified commit d4d20f01, authored by Philip Meier and committed by GitHub

make type alias private (#7266)

parent e405f3c3
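This commit renames the type aliases that `torchvision.datapoints` had been exposing publicly, giving each a leading underscore so that users and type checkers treat them as internal implementation details rather than public API. A minimal sketch of the pattern (alias definitions taken verbatim from the diff below):

```python
from typing import List, Optional, Sequence, Union

# Before: public names, importable from torchvision.datapoints
FillType = Union[int, float, Sequence[int], Sequence[float], None]
FillTypeJIT = Optional[List[float]]

# After: the same aliases, renamed private
_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[List[float]]
```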
 from ._bounding_box import BoundingBox, BoundingBoxFormat
-from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT
-from ._image import Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
+from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT
+from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image
 from ._mask import Mask
-from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT
+from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video
 from ._dataset_wrapper import wrap_dataset_for_transforms_v2  # type: ignore[attr-defined]  # usort: skip
@@ -6,7 +6,7 @@ import torch
 from torchvision._utils import StrEnum
 from torchvision.transforms import InterpolationMode  # TODO: this needs to be moved out of transforms

-from ._datapoint import Datapoint, FillTypeJIT
+from ._datapoint import _FillTypeJIT, Datapoint


 class BoundingBoxFormat(StrEnum):
@@ -136,7 +136,7 @@ class BoundingBox(Datapoint):
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         expand: bool = False,
         center: Optional[List[float]] = None,
-        fill: FillTypeJIT = None,
+        fill: _FillTypeJIT = None,
     ) -> BoundingBox:
         output, spatial_size = self._F.rotate_bounding_box(
             self.as_subclass(torch.Tensor),
@@ -155,7 +155,7 @@ class BoundingBox(Datapoint):
         scale: float,
         shear: List[float],
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: FillTypeJIT = None,
+        fill: _FillTypeJIT = None,
         center: Optional[List[float]] = None,
     ) -> BoundingBox:
         output = self._F.affine_bounding_box(
@@ -175,7 +175,7 @@ class BoundingBox(Datapoint):
         startpoints: Optional[List[List[int]]],
         endpoints: Optional[List[List[int]]],
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: FillTypeJIT = None,
+        fill: _FillTypeJIT = None,
         coefficients: Optional[List[float]] = None,
     ) -> BoundingBox:
         output = self._F.perspective_bounding_box(
@@ -192,7 +192,7 @@ class BoundingBox(Datapoint):
         self,
         displacement: torch.Tensor,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: FillTypeJIT = None,
+        fill: _FillTypeJIT = None,
     ) -> BoundingBox:
         output = self._F.elastic_bounding_box(
             self.as_subclass(torch.Tensor), self.format, self.spatial_size, displacement=displacement
@@ -11,8 +11,8 @@ from torchvision.transforms import InterpolationMode

 D = TypeVar("D", bound="Datapoint")

-FillType = Union[int, float, Sequence[int], Sequence[float], None]
-FillTypeJIT = Optional[List[float]]
+_FillType = Union[int, float, Sequence[int], Sequence[float], None]
+_FillTypeJIT = Optional[List[float]]


 class Datapoint(torch.Tensor):
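The two fill aliases serve different consumers: `_FillType` is everything a user may pass as `fill=` in eager mode, while `_FillTypeJIT` is the narrower form TorchScript can type-check. A quick, self-contained illustration:

```python
from typing import List, Optional, Sequence, Union

_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[List[float]]

fill_eager: _FillType = (127, 127, 127)          # per-channel fill, accepted in eager mode
fill_jit: _FillTypeJIT = [127.0, 127.0, 127.0]   # scripted code only handles Optional[List[float]]
```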
@@ -181,7 +181,7 @@ class Datapoint(torch.Tensor):
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         expand: bool = False,
         center: Optional[List[float]] = None,
-        fill: FillTypeJIT = None,
+        fill: _FillTypeJIT = None,
     ) -> Datapoint:
         return self
@@ -192,7 +192,7 @@ class Datapoint(torch.Tensor):
         scale: float,
         shear: List[float],
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: FillTypeJIT = None,
+        fill: _FillTypeJIT = None,
         center: Optional[List[float]] = None,
     ) -> Datapoint:
         return self
@@ -202,7 +202,7 @@ class Datapoint(torch.Tensor):
         startpoints: Optional[List[List[int]]],
         endpoints: Optional[List[List[int]]],
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: FillTypeJIT = None,
+        fill: _FillTypeJIT = None,
         coefficients: Optional[List[float]] = None,
     ) -> Datapoint:
         return self
@@ -211,7 +211,7 @@ class Datapoint(torch.Tensor):
         self,
         displacement: torch.Tensor,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: FillTypeJIT = None,
+        fill: _FillTypeJIT = None,
     ) -> Datapoint:
         return self
@@ -255,5 +255,5 @@ class Datapoint(torch.Tensor):
         return self


-InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint]
-InputTypeJIT = torch.Tensor
+_InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint]
+_InputTypeJIT = torch.Tensor
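Same idea for the input aliases: `_InputType` covers everything the eager kernels dispatch on, while `_InputTypeJIT` collapses to `torch.Tensor` because TorchScript cannot express the full union. A small illustration:

```python
import PIL.Image
import torch

tensor_inpt = torch.rand(3, 8, 8)         # valid _InputType and _InputTypeJIT
pil_inpt = PIL.Image.new("RGB", (8, 8))   # valid _InputType, but not usable under torch.jit.script
```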
@@ -6,7 +6,7 @@ import PIL.Image
 import torch
 from torchvision.transforms.functional import InterpolationMode

-from ._datapoint import Datapoint, FillTypeJIT
+from ._datapoint import _FillTypeJIT, Datapoint


 class Image(Datapoint):
@@ -116,7 +116,7 @@ class Image(Datapoint):
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         expand: bool = False,
         center: Optional[List[float]] = None,
-        fill: FillTypeJIT = None,
+        fill: _FillTypeJIT = None,
     ) -> Image:
         output = self._F.rotate_image_tensor(
             self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center
@@ -130,7 +130,7 @@ class Image(Datapoint):
         scale: float,
         shear: List[float],
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: FillTypeJIT = None,
+        fill: _FillTypeJIT = None,
         center: Optional[List[float]] = None,
     ) -> Image:
         output = self._F.affine_image_tensor(
@@ -150,7 +150,7 @@ class Image(Datapoint):
         startpoints: Optional[List[List[int]]],
         endpoints: Optional[List[List[int]]],
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: FillTypeJIT = None,
+        fill: _FillTypeJIT = None,
         coefficients: Optional[List[float]] = None,
     ) -> Image:
         output = self._F.perspective_image_tensor(
@@ -167,7 +167,7 @@ class Image(Datapoint):
         self,
         displacement: torch.Tensor,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: FillTypeJIT = None,
+        fill: _FillTypeJIT = None,
     ) -> Image:
         output = self._F.elastic_image_tensor(
             self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill
@@ -241,7 +241,7 @@ class Image(Datapoint):
         return Image.wrap_like(self, output)


-ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
-ImageTypeJIT = torch.Tensor
-TensorImageType = Union[torch.Tensor, Image]
-TensorImageTypeJIT = torch.Tensor
+_ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
+_ImageTypeJIT = torch.Tensor
+_TensorImageType = Union[torch.Tensor, Image]
+_TensorImageTypeJIT = torch.Tensor
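The four image aliases differ only in whether PIL images are admitted and whether the code is scripted; the `_Video*` aliases further down mirror the same pattern minus the PIL branch. A hedged sketch:

```python
import PIL.Image
import torch
from torchvision import datapoints

img_tensor = torch.rand(3, 8, 8)              # _ImageType and _TensorImageType
img_datapoint = datapoints.Image(img_tensor)  # _ImageType and _TensorImageType
img_pil = PIL.Image.new("RGB", (8, 8))        # _ImageType only; the tensor aliases exclude PIL
# Under torch.jit.script, _ImageTypeJIT and _TensorImageTypeJIT both narrow to torch.Tensor.
```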
@@ -6,7 +6,7 @@ import PIL.Image
 import torch
 from torchvision.transforms import InterpolationMode

-from ._datapoint import Datapoint, FillTypeJIT
+from ._datapoint import _FillTypeJIT, Datapoint


 class Mask(Datapoint):
@@ -96,7 +96,7 @@ class Mask(Datapoint):
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         expand: bool = False,
         center: Optional[List[float]] = None,
-        fill: FillTypeJIT = None,
+        fill: _FillTypeJIT = None,
     ) -> Mask:
         output = self._F.rotate_mask(self.as_subclass(torch.Tensor), angle, expand=expand, center=center, fill=fill)
         return Mask.wrap_like(self, output)
@@ -108,7 +108,7 @@ class Mask(Datapoint):
         scale: float,
         shear: List[float],
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: FillTypeJIT = None,
+        fill: _FillTypeJIT = None,
         center: Optional[List[float]] = None,
     ) -> Mask:
         output = self._F.affine_mask(
@@ -127,7 +127,7 @@ class Mask(Datapoint):
         startpoints: Optional[List[List[int]]],
         endpoints: Optional[List[List[int]]],
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: FillTypeJIT = None,
+        fill: _FillTypeJIT = None,
         coefficients: Optional[List[float]] = None,
     ) -> Mask:
         output = self._F.perspective_mask(
@@ -139,7 +139,7 @@ class Mask(Datapoint):
         self,
         displacement: torch.Tensor,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: FillTypeJIT = None,
+        fill: _FillTypeJIT = None,
     ) -> Mask:
         output = self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill)
         return Mask.wrap_like(self, output)
@@ -5,7 +5,7 @@ from typing import Any, List, Optional, Tuple, Union
 import torch
 from torchvision.transforms.functional import InterpolationMode

-from ._datapoint import Datapoint, FillTypeJIT
+from ._datapoint import _FillTypeJIT, Datapoint


 class Video(Datapoint):
@@ -115,7 +115,7 @@ class Video(Datapoint):
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         expand: bool = False,
         center: Optional[List[float]] = None,
-        fill: FillTypeJIT = None,
+        fill: _FillTypeJIT = None,
     ) -> Video:
         output = self._F.rotate_video(
             self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center
@@ -129,7 +129,7 @@ class Video(Datapoint):
         scale: float,
         shear: List[float],
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: FillTypeJIT = None,
+        fill: _FillTypeJIT = None,
         center: Optional[List[float]] = None,
     ) -> Video:
         output = self._F.affine_video(
@@ -149,7 +149,7 @@ class Video(Datapoint):
         startpoints: Optional[List[List[int]]],
         endpoints: Optional[List[List[int]]],
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: FillTypeJIT = None,
+        fill: _FillTypeJIT = None,
         coefficients: Optional[List[float]] = None,
     ) -> Video:
         output = self._F.perspective_video(
@@ -166,7 +166,7 @@ class Video(Datapoint):
         self,
         displacement: torch.Tensor,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: FillTypeJIT = None,
+        fill: _FillTypeJIT = None,
     ) -> Video:
         output = self._F.elastic_video(
             self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill
@@ -232,7 +232,7 @@ class Video(Datapoint):
         return Video.wrap_like(self, output)


-VideoType = Union[torch.Tensor, Video]
-VideoTypeJIT = torch.Tensor
-TensorVideoType = Union[torch.Tensor, Video]
-TensorVideoTypeJIT = torch.Tensor
+_VideoType = Union[torch.Tensor, Video]
+_VideoTypeJIT = torch.Tensor
+_TensorVideoType = Union[torch.Tensor, Video]
+_TensorVideoTypeJIT = torch.Tensor
@@ -119,15 +119,15 @@ class SimpleCopyPaste(Transform):
     def _copy_paste(
         self,
-        image: datapoints.TensorImageType,
+        image: datapoints._TensorImageType,
         target: Dict[str, Any],
-        paste_image: datapoints.TensorImageType,
+        paste_image: datapoints._TensorImageType,
         paste_target: Dict[str, Any],
         random_selection: torch.Tensor,
         blending: bool,
         resize_interpolation: F.InterpolationMode,
         antialias: Optional[bool],
-    ) -> Tuple[datapoints.TensorImageType, Dict[str, Any]]:
+    ) -> Tuple[datapoints._TensorImageType, Dict[str, Any]]:
         paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection])
         paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection])
@@ -199,7 +199,7 @@ class SimpleCopyPaste(Transform):
     def _extract_image_targets(
         self, flat_sample: List[Any]
-    ) -> Tuple[List[datapoints.TensorImageType], List[Dict[str, Any]]]:
+    ) -> Tuple[List[datapoints._TensorImageType], List[Dict[str, Any]]]:
         # fetch all images, bboxes, masks and labels from unstructured input
         # with List[image], List[BoundingBox], List[Mask], List[Label]
         images, bboxes, masks, labels = [], [], [], []
@@ -230,7 +230,7 @@ class SimpleCopyPaste(Transform):
     def _insert_outputs(
         self,
         flat_sample: List[Any],
-        output_images: List[datapoints.TensorImageType],
+        output_images: List[datapoints._TensorImageType],
         output_targets: List[Dict[str, Any]],
     ) -> None:
         c0, c1, c2, c3 = 0, 0, 0, 0
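The repeated `wrap_like` calls above are the standard way to re-attach the datapoint subclass after an operation that returns plain tensor data. A minimal sketch of the pattern:

```python
import torch
from torchvision import datapoints

masks = datapoints.Mask(torch.zeros(4, 8, 8, dtype=torch.bool))
selection = torch.tensor([0, 2])
# Indexing yields tensor data; wrap_like restores the Mask subclass around it.
selected = datapoints.Mask.wrap_like(masks, masks[selection])
```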
@@ -14,7 +14,7 @@ class FixedSizeCrop(Transform):
     def __init__(
         self,
         size: Union[int, Sequence[int]],
-        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
         padding_mode: str = "constant",
     ) -> None:
         super().__init__()
@@ -26,7 +26,7 @@ class PermuteDimensions(Transform):
         self.dims = dims

     def _transform(
-        self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
+        self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
     ) -> torch.Tensor:
         dims = self.dims[type(inpt)]
         if dims is None:
@@ -50,7 +50,7 @@ class TransposeDimensions(Transform):
         self.dims = dims

     def _transform(
-        self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
+        self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
     ) -> torch.Tensor:
         dims = self.dims[type(inpt)]
         if dims is None:
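Note the lookup `self.dims[type(inpt)]`: like `fill` elsewhere in this diff, `dims` is resolved per concrete input type, with `None` meaning pass-through. An illustrative sketch of that resolution (standalone, not the transform's actual constructor):

```python
import torch
from torchvision import datapoints

dims = {
    datapoints.Image: (1, 2, 0),  # e.g. CHW -> HWC
    datapoints.Video: None,       # None: leave videos untouched
}

inpt = datapoints.Image(torch.rand(3, 8, 8))
perm = dims[type(inpt)]
out = inpt.permute(*perm) if perm is not None else inpt
```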
@@ -97,8 +97,8 @@ class RandomErasing(_RandomApplyTransform):
         return dict(i=i, j=j, h=h, w=w, v=v)

     def _transform(
-        self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
-    ) -> Union[datapoints.ImageType, datapoints.VideoType]:
+        self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
+    ) -> Union[datapoints._ImageType, datapoints._VideoType]:
         if params["v"] is not None:
             inpt = F.erase(inpt, **params, inplace=self.inplace)
@@ -20,7 +20,7 @@ class _AutoAugmentBase(Transform):
         self,
         *,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
+        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
     ) -> None:
         super().__init__()
         self.interpolation = _check_interpolation(interpolation)
@@ -35,7 +35,7 @@ class _AutoAugmentBase(Transform):
         self,
         inputs: Any,
         unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBox, datapoints.Mask),
-    ) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints.ImageType, datapoints.VideoType]]:
+    ) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints._ImageType, datapoints._VideoType]]:
         flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])

         needs_transform_list = self._needs_transform_list(flat_inputs)
@@ -68,7 +68,7 @@ class _AutoAugmentBase(Transform):
     def _unflatten_and_insert_image_or_video(
         self,
         flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
-        image_or_video: Union[datapoints.ImageType, datapoints.VideoType],
+        image_or_video: Union[datapoints._ImageType, datapoints._VideoType],
     ) -> Any:
         flat_inputs, spec, idx = flat_inputs_with_spec
         flat_inputs[idx] = image_or_video
@@ -76,12 +76,12 @@ class _AutoAugmentBase(Transform):
     def _apply_image_or_video_transform(
         self,
-        image: Union[datapoints.ImageType, datapoints.VideoType],
+        image: Union[datapoints._ImageType, datapoints._VideoType],
         transform_id: str,
         magnitude: float,
         interpolation: Union[InterpolationMode, int],
-        fill: Dict[Type, datapoints.FillTypeJIT],
-    ) -> Union[datapoints.ImageType, datapoints.VideoType]:
+        fill: Dict[Type, datapoints._FillTypeJIT],
+    ) -> Union[datapoints._ImageType, datapoints._VideoType]:
         fill_ = fill[type(image)]

         if transform_id == "Identity":
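The line `fill_ = fill[type(image)]` means the dict passed down here must be keyed by the concrete input classes, with `_FillTypeJIT` values. An illustrative mapping:

```python
from typing import Dict, List, Optional, Type

import torch
from torchvision import datapoints

fill: Dict[Type, Optional[List[float]]] = {
    datapoints.Image: [127.0, 127.0, 127.0],
    datapoints.Video: None,
    torch.Tensor: [0.0],
}

image = datapoints.Image(torch.rand(3, 8, 8))
fill_ = fill[type(image)]  # -> [127.0, 127.0, 127.0]
```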
@@ -194,7 +194,7 @@ class AutoAugment(_AutoAugmentBase):
         self,
         policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
+        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self.policy = policy
@@ -351,7 +351,7 @@ class RandAugment(_AutoAugmentBase):
         magnitude: int = 9,
         num_magnitude_bins: int = 31,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
+        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_ops = num_ops
@@ -404,7 +404,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
         self,
         num_magnitude_bins: int = 31,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
+        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
     ):
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_magnitude_bins = num_magnitude_bins
@@ -462,7 +462,7 @@ class AugMix(_AutoAugmentBase):
         alpha: float = 1.0,
         all_ops: bool = True,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
+        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self._PARAMETER_MAX = 10
@@ -163,8 +163,8 @@ class RandomPhotometricDistort(Transform):
         )

     def _permute_channels(
-        self, inpt: Union[datapoints.ImageType, datapoints.VideoType], permutation: torch.Tensor
-    ) -> Union[datapoints.ImageType, datapoints.VideoType]:
+        self, inpt: Union[datapoints._ImageType, datapoints._VideoType], permutation: torch.Tensor
+    ) -> Union[datapoints._ImageType, datapoints._VideoType]:
         orig_inpt = inpt
         if isinstance(orig_inpt, PIL.Image.Image):
@@ -179,8 +179,8 @@ class RandomPhotometricDistort(Transform):
         return output

     def _transform(
-        self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
-    ) -> Union[datapoints.ImageType, datapoints.VideoType]:
+        self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
+    ) -> Union[datapoints._ImageType, datapoints._VideoType]:
         if params["brightness"]:
             inpt = F.adjust_brightness(
                 inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1])
@@ -160,7 +160,7 @@ class RandomResizedCrop(Transform):
         )


-ImageOrVideoTypeJIT = Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]
+ImageOrVideoTypeJIT = Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]


 class FiveCrop(Transform):
@@ -232,7 +232,7 @@ class TenCrop(Transform):
             raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")

     def _transform(
-        self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
+        self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
     ) -> Tuple[
         ImageOrVideoTypeJIT,
         ImageOrVideoTypeJIT,
@@ -264,7 +264,7 @@ class Pad(Transform):
     def __init__(
         self,
         padding: Union[int, Sequence[int]],
-        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
@@ -287,7 +287,7 @@ class Pad(Transform):
 class RandomZoomOut(_RandomApplyTransform):
     def __init__(
         self,
-        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
         side_range: Sequence[float] = (1.0, 4.0),
         p: float = 0.5,
     ) -> None:
@@ -330,7 +330,7 @@ class RandomRotation(Transform):
         degrees: Union[numbers.Number, Sequence],
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
         center: Optional[List[float]] = None,
     ) -> None:
         super().__init__()
@@ -371,7 +371,7 @@ class RandomAffine(Transform):
         scale: Optional[Sequence[float]] = None,
         shear: Optional[Union[int, float, Sequence[float]]] = None,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
         center: Optional[List[float]] = None,
     ) -> None:
         super().__init__()
@@ -464,7 +464,7 @@ class RandomCrop(Transform):
         size: Union[int, Sequence[int]],
         padding: Optional[Union[int, Sequence[int]]] = None,
         pad_if_needed: bool = False,
-        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
@@ -556,7 +556,7 @@ class RandomPerspective(_RandomApplyTransform):
     def __init__(
         self,
         distortion_scale: float = 0.5,
-        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
         p: float = 0.5,
     ) -> None:
@@ -618,7 +618,7 @@ class ElasticTransform(Transform):
         self,
         alpha: Union[float, Sequence[float]] = 50.0,
         sigma: Union[float, Sequence[float]] = 5.0,
-        fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     ) -> None:
         super().__init__()
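All of the geometry transforms above accept `fill` either as a single `_FillType` applied to every input or as a per-type dict, exactly as `_AutoAugmentBase` does. A hedged usage sketch (assuming these transforms are imported from the prototype transforms namespace of this era):

```python
from torchvision import datapoints
from torchvision.prototype import transforms  # assumption: prototype namespace of this era

pad_uniform = transforms.Pad(padding=4, fill=127)  # one fill for every input type
pad_per_type = transforms.Pad(
    padding=4,
    fill={datapoints.Image: 127, datapoints.Mask: 0},  # different fills per datapoint type
)
```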
@@ -31,8 +31,8 @@ class ConvertDtype(Transform):
         self.dtype = dtype

     def _transform(
-        self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
-    ) -> Union[datapoints.TensorImageType, datapoints.TensorVideoType]:
+        self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
+    ) -> Union[datapoints._TensorImageType, datapoints._TensorVideoType]:
         return F.convert_dtype(inpt, self.dtype)
@@ -119,7 +119,7 @@ class Normalize(Transform):
             raise TypeError(f"{type(self).__name__}() does not support PIL images.")

     def _transform(
-        self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
+        self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
     ) -> Any:
         return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
@@ -13,5 +13,5 @@ class UniformTemporalSubsample(Transform):
         super().__init__()
         self.num_samples = num_samples

-    def _transform(self, inpt: datapoints.VideoType, params: Dict[str, Any]) -> datapoints.VideoType:
+    def _transform(self, inpt: datapoints._VideoType, params: Dict[str, Any]) -> datapoints._VideoType:
         return F.uniform_temporal_subsample(inpt, self.num_samples)
@@ -4,7 +4,7 @@ from collections import defaultdict
 from typing import Any, Dict, Literal, Sequence, Type, TypeVar, Union

 from torchvision import datapoints
-from torchvision.datapoints._datapoint import FillType, FillTypeJIT
+from torchvision.datapoints._datapoint import _FillType, _FillTypeJIT
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401
@@ -26,7 +26,7 @@ def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size:
     return arg


-def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
+def _check_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> None:
     if isinstance(fill, dict):
         for key, value in fill.items():
             # Check key for type
@@ -52,7 +52,7 @@ def _get_defaultdict(default: T) -> Dict[Any, T]:
     return defaultdict(functools.partial(_default_arg, default))


-def _convert_fill_arg(fill: datapoints.FillType) -> datapoints.FillTypeJIT:
+def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
     # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
     # So, we can't reassign fill to 0
     # if fill is None:
@@ -65,7 +65,7 @@ def _convert_fill_arg(fill: datapoints.FillType) -> datapoints.FillTypeJIT:
     return fill  # type: ignore[return-value]


-def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillTypeJIT]:
+def _setup_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> Dict[Type, _FillTypeJIT]:
     _check_fill_arg(fill)

     if isinstance(fill, dict):
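Reading the three helpers together: `_check_fill_arg` validates the user-facing value, `_convert_fill_arg` narrows a single `_FillType` to the JIT-friendly `_FillTypeJIT`, and `_setup_fill_arg` combines them into a per-type mapping. An illustrative sketch of the intended flow (the exact return shape is an assumption based on the signatures above):

```python
from torchvision import datapoints

# A scalar fill should come back as a mapping usable for any datapoint type ...
fill_map = _setup_fill_arg(0)

# ... while a dict is validated and converted entry by entry.
fill_map = _setup_fill_arg({datapoints.Image: (127, 127, 127), datapoints.Mask: None})
jit_fill = fill_map[datapoints.Image]  # a _FillTypeJIT value, e.g. [127.0, 127.0, 127.0]
```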
@@ -36,14 +36,14 @@ def erase_video(
 def erase(
-    inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT],
+    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT],
     i: int,
     j: int,
     h: int,
     w: int,
     v: torch.Tensor,
     inplace: bool = False,
-) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
+) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
     if not torch.jit.is_scripting():
         _log_api_usage_once(erase)
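A hedged usage sketch of the dispatcher above: blank out a 10x10 region whose top-left corner sits at row 5, column 5 (`i`/`j` are the top/left offsets, `v` holds the replacement values):

```python
import torch

img = torch.rand(3, 32, 32)
patch = torch.zeros(3, 10, 10)  # replacement values for the erased region
erased = erase(img, i=5, j=5, h=10, w=10, v=patch)
```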
@@ -37,8 +37,8 @@ rgb_to_grayscale_image_pil = _FP.to_grayscale

 def rgb_to_grayscale(
-    inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1
-) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
+    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], num_output_channels: int = 1
+) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
     if not torch.jit.is_scripting():
         _log_api_usage_once(rgb_to_grayscale)
     if num_output_channels not in (1, 3):
@@ -85,7 +85,7 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor:
     return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor)


-def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) -> datapoints.InputTypeJIT:
+def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(adjust_brightness)
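Each `adjust_*` dispatcher below follows the same shape: log API usage when not scripting, then route to the tensor, PIL, or video kernel. A hedged usage sketch with plain tensors:

```python
import torch

img = torch.rand(3, 32, 32)
brighter = adjust_brightness(img, brightness_factor=1.5)  # plain tensors stay plain tensors
darker = adjust_brightness(img, brightness_factor=0.5)
```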
@@ -127,7 +127,7 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor:
     return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor)


-def adjust_saturation(inpt: datapoints.InputTypeJIT, saturation_factor: float) -> datapoints.InputTypeJIT:
+def adjust_saturation(inpt: datapoints._InputTypeJIT, saturation_factor: float) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(adjust_saturation)
@@ -171,7 +171,7 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor:
     return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor)


-def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> datapoints.InputTypeJIT:
+def adjust_contrast(inpt: datapoints._InputTypeJIT, contrast_factor: float) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(adjust_contrast)
@@ -247,7 +247,7 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
     return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor)


-def adjust_sharpness(inpt: datapoints.InputTypeJIT, sharpness_factor: float) -> datapoints.InputTypeJIT:
+def adjust_sharpness(inpt: datapoints._InputTypeJIT, sharpness_factor: float) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(adjust_sharpness)
@@ -364,7 +364,7 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
     return adjust_hue_image_tensor(video, hue_factor=hue_factor)


-def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.InputTypeJIT:
+def adjust_hue(inpt: datapoints._InputTypeJIT, hue_factor: float) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(adjust_hue)
@@ -407,7 +407,7 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
     return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain)


-def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) -> datapoints.InputTypeJIT:
+def adjust_gamma(inpt: datapoints._InputTypeJIT, gamma: float, gain: float = 1) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(adjust_gamma)
@@ -444,7 +444,7 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
     return posterize_image_tensor(video, bits=bits)


-def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJIT:
+def posterize(inpt: datapoints._InputTypeJIT, bits: int) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(posterize)
@@ -475,7 +475,7 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
     return solarize_image_tensor(video, threshold=threshold)


-def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.InputTypeJIT:
+def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(solarize)
@@ -528,7 +528,7 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
     return autocontrast_image_tensor(video)


-def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
+def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(autocontrast)
@@ -621,7 +621,7 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
     return equalize_image_tensor(video)


-def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
+def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(equalize)
@@ -655,7 +655,7 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
     return invert_image_tensor(video)


-def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
+def invert(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(invert)
@@ -31,7 +31,7 @@ def to_tensor(inpt: Any) -> torch.Tensor:
     return _F.to_tensor(inpt)


-def get_image_size(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]) -> List[int]:
+def get_image_size(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
     warnings.warn(
         "The function `get_image_size(...)` is deprecated and will be removed in a future release. "
         "Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`."