"vscode:/vscode.git/clone" did not exist on "8cd071b6e9af48710d82192d68fe4e41d041d7e4"
Unverified commit ca012d39, authored by Philip Meier, committed by GitHub

make PIL kernels private (#7831)

parent cdbbd666
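Taken together, the renames below drop the `_tensor` suffix from the tensor kernels (`horizontal_flip_image_tensor` becomes `horizontal_flip_image`) and prefix every PIL-specific kernel with an underscore, so only the type-dispatching functionals stay public. A minimal sketch of the resulting call pattern, assuming the `torchvision.transforms.v2.functional` namespace of this development cycle:

import PIL.Image
import torch
from torchvision.transforms.v2 import functional as F

img_tensor = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
img_pil = PIL.Image.new("RGB", (32, 32))

out_tensor = F.horizontal_flip(img_tensor)  # dispatches to horizontal_flip_image
out_pil = F.horizontal_flip(img_pil)        # dispatches to the now-private _horizontal_flip_image_pil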
@@ -10,7 +10,7 @@ from torchvision.transforms import functional as _F
 def to_tensor(inpt: Any) -> torch.Tensor:
     warnings.warn(
         "The function `to_tensor(...)` is deprecated and will be removed in a future release. "
-        "Instead, please use `to_image_tensor(...)` followed by `to_dtype(..., dtype=torch.float32, scale=True)`."
+        "Instead, please use `to_image(...)` followed by `to_dtype(..., dtype=torch.float32, scale=True)`."
     )
     return _F.to_tensor(inpt)
...
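For reference, the updated deprecation message points at a two-step replacement. A short migration sketch (module path assumed as above):

import PIL.Image
import torch
from torchvision.transforms.v2 import functional as F

pil_image = PIL.Image.new("RGB", (8, 8))

# Deprecated one-liner: tensor = F.to_tensor(pil_image)
image = F.to_image(pil_image)                                # uint8 Image datapoint
tensor = F.to_dtype(image, dtype=torch.float32, scale=True)  # float32, scaled to [0, 1]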
@@ -23,7 +23,7 @@ from torchvision.transforms.functional import (
 from torchvision.utils import _log_api_usage_once
 
-from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil
+from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_format_bounding_boxes
 from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal
@@ -41,7 +41,7 @@ def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> Interp
 def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return horizontal_flip_image_tensor(inpt)
+        return horizontal_flip_image(inpt)
 
     _log_api_usage_once(horizontal_flip)
@@ -51,18 +51,18 @@ def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor:
 @_register_kernel_internal(horizontal_flip, torch.Tensor)
 @_register_kernel_internal(horizontal_flip, datapoints.Image)
-def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
+def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor:
     return image.flip(-1)
 
 @_register_kernel_internal(horizontal_flip, PIL.Image.Image)
-def horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
+def _horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
     return _FP.hflip(image)
 
 @_register_kernel_internal(horizontal_flip, datapoints.Mask)
 def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
-    return horizontal_flip_image_tensor(mask)
+    return horizontal_flip_image(mask)
 
 def horizontal_flip_bounding_boxes(
@@ -92,12 +92,12 @@ def _horizontal_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) ->
 @_register_kernel_internal(horizontal_flip, datapoints.Video)
 def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
-    return horizontal_flip_image_tensor(video)
+    return horizontal_flip_image(video)
 
 def vertical_flip(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return vertical_flip_image_tensor(inpt)
+        return vertical_flip_image(inpt)
 
     _log_api_usage_once(vertical_flip)
@@ -107,18 +107,18 @@ def vertical_flip(inpt: torch.Tensor) -> torch.Tensor:
 @_register_kernel_internal(vertical_flip, torch.Tensor)
 @_register_kernel_internal(vertical_flip, datapoints.Image)
-def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
+def vertical_flip_image(image: torch.Tensor) -> torch.Tensor:
     return image.flip(-2)
 
 @_register_kernel_internal(vertical_flip, PIL.Image.Image)
-def vertical_flip_image_pil(image: PIL.Image) -> PIL.Image:
+def _vertical_flip_image_pil(image: PIL.Image) -> PIL.Image:
     return _FP.vflip(image)
 
 @_register_kernel_internal(vertical_flip, datapoints.Mask)
 def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
-    return vertical_flip_image_tensor(mask)
+    return vertical_flip_image(mask)
 
 def vertical_flip_bounding_boxes(
@@ -148,7 +148,7 @@ def _vertical_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> da
 @_register_kernel_internal(vertical_flip, datapoints.Video)
 def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
-    return vertical_flip_image_tensor(video)
+    return vertical_flip_image(video)
 
 # We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
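As a usage sketch of the registrations above, the same public call works across datapoint types, with each input routed to its registered kernel (datapoint classes assumed from `torchvision.datapoints` of this era):

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

image = datapoints.Image(torch.rand(3, 4, 4))
mask = datapoints.Mask(torch.zeros(4, 4, dtype=torch.uint8))

flipped_image = F.vertical_flip(image)  # routed to vertical_flip_image
flipped_mask = F.vertical_flip(mask)    # routed to vertical_flip_mask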
@@ -178,7 +178,7 @@ def resize(
     antialias: Optional[Union[str, bool]] = "warn",
 ) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return resize_image_tensor(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
+        return resize_image(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
 
     _log_api_usage_once(resize)
@@ -188,7 +188,7 @@ def resize(
 @_register_kernel_internal(resize, torch.Tensor)
 @_register_kernel_internal(resize, datapoints.Image)
-def resize_image_tensor(
+def resize_image(
     image: torch.Tensor,
     size: List[int],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
@@ -267,7 +267,7 @@ def resize_image_tensor(
     return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
 
-def resize_image_pil(
+def _resize_image_pil(
     image: PIL.Image.Image,
     size: Union[Sequence[int], int],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
@@ -289,7 +289,7 @@ def resize_image_pil(
 @_register_kernel_internal(resize, PIL.Image.Image)
-def _resize_image_pil_dispatch(
+def __resize_image_pil_dispatch(
     image: PIL.Image.Image,
     size: Union[Sequence[int], int],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
@@ -298,7 +298,7 @@ def _resize_image_pil_dispatch(
 ) -> PIL.Image.Image:
     if antialias is False:
         warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
-    return resize_image_pil(image, size=size, interpolation=interpolation, max_size=max_size)
+    return _resize_image_pil(image, size=size, interpolation=interpolation, max_size=max_size)
 
 def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
@@ -308,7 +308,7 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N
     else:
         needs_squeeze = False
 
-    output = resize_image_tensor(mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)
+    output = resize_image(mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)
 
     if needs_squeeze:
         output = output.squeeze(0)
@@ -360,7 +360,7 @@ def resize_video(
     max_size: Optional[int] = None,
     antialias: Optional[Union[str, bool]] = "warn",
 ) -> torch.Tensor:
-    return resize_image_tensor(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
+    return resize_image(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
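One behavioural detail from the PIL dispatch above: `antialias=False` is ignored for PIL inputs, since PIL always anti-aliases, and a warning is emitted instead. A sketch:

import PIL.Image
from torchvision.transforms.v2 import functional as F

pil_img = PIL.Image.new("RGB", (64, 48))
resized = F.resize(pil_img, size=[24, 32], antialias=False)  # emits a UserWarning; antialias is ignored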
 def affine(
@@ -374,7 +374,7 @@ def affine(
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return affine_image_tensor(
+        return affine_image(
             inpt,
             angle=angle,
             translate=translate,
@@ -648,7 +648,7 @@ def _affine_grid(
 @_register_kernel_internal(affine, torch.Tensor)
 @_register_kernel_internal(affine, datapoints.Image)
-def affine_image_tensor(
+def affine_image(
     image: torch.Tensor,
     angle: Union[int, float],
     translate: List[float],
@@ -700,7 +700,7 @@ def affine_image_tensor(
 @_register_kernel_internal(affine, PIL.Image.Image)
-def affine_image_pil(
+def _affine_image_pil(
     image: PIL.Image.Image,
     angle: Union[int, float],
     translate: List[float],
@@ -717,7 +717,7 @@ def affine_image_pil(
     # it is visually better to estimate the center without 0.5 offset
     # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
     if center is None:
-        height, width = get_size_image_pil(image)
+        height, width = _get_size_image_pil(image)
         center = [width * 0.5, height * 0.5]
 
     matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
@@ -875,7 +875,7 @@ def affine_mask(
     else:
         needs_squeeze = False
 
-    output = affine_image_tensor(
+    output = affine_image(
         mask,
         angle=angle,
         translate=translate,
@@ -926,7 +926,7 @@ def affine_video(
     fill: _FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
-    return affine_image_tensor(
+    return affine_image(
         video,
         angle=angle,
         translate=translate,
@@ -947,9 +947,7 @@ def rotate(
     fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return rotate_image_tensor(
-            inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center
-        )
+        return rotate_image(inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
 
     _log_api_usage_once(rotate)
@@ -959,7 +957,7 @@ def rotate(
 @_register_kernel_internal(rotate, torch.Tensor)
 @_register_kernel_internal(rotate, datapoints.Image)
-def rotate_image_tensor(
+def rotate_image(
     image: torch.Tensor,
     angle: float,
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
@@ -1004,7 +1002,7 @@ def rotate_image_tensor(
 @_register_kernel_internal(rotate, PIL.Image.Image)
-def rotate_image_pil(
+def _rotate_image_pil(
     image: PIL.Image.Image,
     angle: float,
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
@@ -1074,7 +1072,7 @@ def rotate_mask(
     else:
         needs_squeeze = False
 
-    output = rotate_image_tensor(
+    output = rotate_image(
         mask,
         angle=angle,
         expand=expand,
@@ -1111,7 +1109,7 @@ def rotate_video(
     center: Optional[List[float]] = None,
     fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
-    return rotate_image_tensor(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
+    return rotate_image(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
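A quick sketch of the public `rotate` entry point, which now forwards to `rotate_image`; with `expand=True` the output canvas grows to hold the whole rotated image:

import torch
from torchvision.transforms.v2 import functional as F

img = torch.rand(3, 32, 48)
rotated = F.rotate(img, angle=30.0, expand=True)  # output is larger than 32x48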
 def pad(
@@ -1121,7 +1119,7 @@ def pad(
     padding_mode: str = "constant",
 ) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return pad_image_tensor(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
+        return pad_image(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
 
     _log_api_usage_once(pad)
@@ -1155,7 +1153,7 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
 @_register_kernel_internal(pad, torch.Tensor)
 @_register_kernel_internal(pad, datapoints.Image)
-def pad_image_tensor(
+def pad_image(
     image: torch.Tensor,
     padding: List[int],
     fill: Optional[Union[int, float, List[float]]] = None,
@@ -1253,7 +1251,7 @@ def _pad_with_vector_fill(
     return output
 
-pad_image_pil = _register_kernel_internal(pad, PIL.Image.Image)(_FP.pad)
+_pad_image_pil = _register_kernel_internal(pad, PIL.Image.Image)(_FP.pad)
 
 @_register_kernel_internal(pad, datapoints.Mask)
@@ -1275,7 +1273,7 @@ def pad_mask(
     else:
         needs_squeeze = False
 
-    output = pad_image_tensor(mask, padding=padding, fill=fill, padding_mode=padding_mode)
+    output = pad_image(mask, padding=padding, fill=fill, padding_mode=padding_mode)
 
     if needs_squeeze:
         output = output.squeeze(0)
@@ -1331,12 +1329,12 @@ def pad_video(
     fill: Optional[Union[int, float, List[float]]] = None,
     padding_mode: str = "constant",
 ) -> torch.Tensor:
-    return pad_image_tensor(video, padding, fill=fill, padding_mode=padding_mode)
+    return pad_image(video, padding, fill=fill, padding_mode=padding_mode)
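For completeness, a sketch of `pad`, whose `padding` argument follows the usual torchvision convention (a single int, a 2-list for left/right and top/bottom, or a 4-list for left, top, right, bottom):

import torch
from torchvision.transforms.v2 import functional as F

img = torch.rand(3, 10, 10)
padded = F.pad(img, padding=[1, 2, 3, 4])  # left=1, top=2, right=3, bottom=4
# height: 10 + 2 + 4 = 16, width: 10 + 1 + 3 = 14 -> padded.shape == (3, 16, 14)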
 def crop(inpt: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return crop_image_tensor(inpt, top=top, left=left, height=height, width=width)
+        return crop_image(inpt, top=top, left=left, height=height, width=width)
 
     _log_api_usage_once(crop)
@@ -1346,7 +1344,7 @@ def crop(inpt: torch.Tensor, top: int, left: int, height: int, width: int) -> to
 @_register_kernel_internal(crop, torch.Tensor)
 @_register_kernel_internal(crop, datapoints.Image)
-def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
+def crop_image(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
     h, w = image.shape[-2:]
     right = left + width
@@ -1364,8 +1362,8 @@ def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, wid
     return image[..., top:bottom, left:right]
 
-crop_image_pil = _FP.crop
-_register_kernel_internal(crop, PIL.Image.Image)(crop_image_pil)
+_crop_image_pil = _FP.crop
+_register_kernel_internal(crop, PIL.Image.Image)(_crop_image_pil)
 
 def crop_bounding_boxes(
@@ -1407,7 +1405,7 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int)
     else:
         needs_squeeze = False
 
-    output = crop_image_tensor(mask, top, left, height, width)
+    output = crop_image(mask, top, left, height, width)
 
     if needs_squeeze:
         output = output.squeeze(0)
@@ -1417,7 +1415,7 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int)
 @_register_kernel_internal(crop, datapoints.Video)
 def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
-    return crop_image_tensor(video, top, left, height, width)
+    return crop_image(video, top, left, height, width)
 def perspective(
@@ -1429,7 +1427,7 @@ def perspective(
     coefficients: Optional[List[float]] = None,
 ) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return perspective_image_tensor(
+        return perspective_image(
             inpt,
             startpoints=startpoints,
             endpoints=endpoints,
@@ -1500,7 +1498,7 @@ def _perspective_coefficients(
 @_register_kernel_internal(perspective, torch.Tensor)
 @_register_kernel_internal(perspective, datapoints.Image)
-def perspective_image_tensor(
+def perspective_image(
     image: torch.Tensor,
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
@@ -1547,7 +1545,7 @@ def perspective_image_tensor(
 @_register_kernel_internal(perspective, PIL.Image.Image)
-def perspective_image_pil(
+def _perspective_image_pil(
     image: PIL.Image.Image,
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
@@ -1686,7 +1684,7 @@ def perspective_mask(
     else:
         needs_squeeze = False
 
-    output = perspective_image_tensor(
+    output = perspective_image(
         mask, startpoints, endpoints, interpolation=InterpolationMode.NEAREST, fill=fill, coefficients=coefficients
     )
@@ -1724,7 +1722,7 @@ def perspective_video(
     fill: _FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
 ) -> torch.Tensor:
-    return perspective_image_tensor(
+    return perspective_image(
         video, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
     )
@@ -1736,7 +1734,7 @@ def elastic(
     fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return elastic_image_tensor(inpt, displacement=displacement, interpolation=interpolation, fill=fill)
+        return elastic_image(inpt, displacement=displacement, interpolation=interpolation, fill=fill)
 
     _log_api_usage_once(elastic)
@@ -1749,7 +1747,7 @@ elastic_transform = elastic
 @_register_kernel_internal(elastic, torch.Tensor)
 @_register_kernel_internal(elastic, datapoints.Image)
-def elastic_image_tensor(
+def elastic_image(
     image: torch.Tensor,
     displacement: torch.Tensor,
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
@@ -1809,14 +1807,14 @@ def elastic_image_tensor(
 @_register_kernel_internal(elastic, PIL.Image.Image)
-def elastic_image_pil(
+def _elastic_image_pil(
     image: PIL.Image.Image,
     displacement: torch.Tensor,
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     fill: _FillTypeJIT = None,
 ) -> PIL.Image.Image:
     t_img = pil_to_tensor(image)
-    output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill)
+    output = elastic_image(t_img, displacement, interpolation=interpolation, fill=fill)
     return to_pil_image(output, mode=image.mode)
@@ -1910,7 +1908,7 @@ def elastic_mask(
     else:
         needs_squeeze = False
 
-    output = elastic_image_tensor(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST, fill=fill)
+    output = elastic_image(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST, fill=fill)
 
     if needs_squeeze:
         output = output.squeeze(0)
@@ -1933,12 +1931,12 @@ def elastic_video(
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
-    return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)
+    return elastic_image(video, displacement, interpolation=interpolation, fill=fill)
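The `displacement` argument is a dense flow field; to my understanding it has shape `(1, H, W, 2)` matching the input size, so an all-zero field should act as the identity. A hedged sketch:

import torch
from torchvision.transforms.v2 import functional as F

img = torch.rand(3, 8, 8)
displacement = torch.zeros(1, 8, 8, 2)  # zero flow field: output should equal the input
out = F.elastic(img, displacement=displacement)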
 def center_crop(inpt: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return center_crop_image_tensor(inpt, output_size=output_size)
+        return center_crop_image(inpt, output_size=output_size)
 
     _log_api_usage_once(center_crop)
@@ -1975,7 +1973,7 @@ def _center_crop_compute_crop_anchor(
 @_register_kernel_internal(center_crop, torch.Tensor)
 @_register_kernel_internal(center_crop, datapoints.Image)
-def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
+def center_crop_image(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
     shape = image.shape
 
     if image.numel() == 0:
@@ -1995,20 +1993,20 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> tor
 @_register_kernel_internal(center_crop, PIL.Image.Image)
-def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
+def _center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    image_height, image_width = get_size_image_pil(image)
+    image_height, image_width = _get_size_image_pil(image)
 
     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
-        image = pad_image_pil(image, padding_ltrb, fill=0)
+        image = _pad_image_pil(image, padding_ltrb, fill=0)
 
-        image_height, image_width = get_size_image_pil(image)
+        image_height, image_width = _get_size_image_pil(image)
         if crop_width == image_width and crop_height == image_height:
             return image
 
     crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
-    return crop_image_pil(image, crop_top, crop_left, crop_height, crop_width)
+    return _crop_image_pil(image, crop_top, crop_left, crop_height, crop_width)
 
 def center_crop_bounding_boxes(
@@ -2042,7 +2040,7 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor
     else:
         needs_squeeze = False
 
-    output = center_crop_image_tensor(image=mask, output_size=output_size)
+    output = center_crop_image(image=mask, output_size=output_size)
 
     if needs_squeeze:
         output = output.squeeze(0)
@@ -2052,7 +2050,7 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor
 @_register_kernel_internal(center_crop, datapoints.Video)
 def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:
-    return center_crop_image_tensor(video, output_size)
+    return center_crop_image(video, output_size)
 def resized_crop(
@@ -2066,7 +2064,7 @@ def resized_crop(
     antialias: Optional[Union[str, bool]] = "warn",
 ) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return resized_crop_image_tensor(
+        return resized_crop_image(
             inpt,
             top=top,
             left=left,
@@ -2094,7 +2092,7 @@ def resized_crop(
 @_register_kernel_internal(resized_crop, torch.Tensor)
 @_register_kernel_internal(resized_crop, datapoints.Image)
-def resized_crop_image_tensor(
+def resized_crop_image(
     image: torch.Tensor,
     top: int,
     left: int,
@@ -2104,11 +2102,11 @@ def resized_crop_image_tensor(
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     antialias: Optional[Union[str, bool]] = "warn",
 ) -> torch.Tensor:
-    image = crop_image_tensor(image, top, left, height, width)
-    return resize_image_tensor(image, size, interpolation=interpolation, antialias=antialias)
+    image = crop_image(image, top, left, height, width)
+    return resize_image(image, size, interpolation=interpolation, antialias=antialias)
 
-def resized_crop_image_pil(
+def _resized_crop_image_pil(
     image: PIL.Image.Image,
     top: int,
     left: int,
@@ -2117,12 +2115,12 @@ def resized_crop_image_pil(
     size: List[int],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
 ) -> PIL.Image.Image:
-    image = crop_image_pil(image, top, left, height, width)
-    return resize_image_pil(image, size, interpolation=interpolation)
+    image = _crop_image_pil(image, top, left, height, width)
+    return _resize_image_pil(image, size, interpolation=interpolation)
 
 @_register_kernel_internal(resized_crop, PIL.Image.Image)
-def resized_crop_image_pil_dispatch(
+def _resized_crop_image_pil_dispatch(
     image: PIL.Image.Image,
     top: int,
     left: int,
@@ -2134,7 +2132,7 @@ def resized_crop_image_pil_dispatch(
 ) -> PIL.Image.Image:
     if antialias is False:
         warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
-    return resized_crop_image_pil(
+    return _resized_crop_image_pil(
         image,
         top=top,
         left=left,
@@ -2201,7 +2199,7 @@ def resized_crop_video(
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     antialias: Optional[Union[str, bool]] = "warn",
 ) -> torch.Tensor:
-    return resized_crop_image_tensor(
+    return resized_crop_image(
         video, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
     )
@@ -2210,7 +2208,7 @@ def five_crop(
     inpt: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     if torch.jit.is_scripting():
-        return five_crop_image_tensor(inpt, size=size)
+        return five_crop_image(inpt, size=size)
 
     _log_api_usage_once(five_crop)
@@ -2234,7 +2232,7 @@ def _parse_five_crop_size(size: List[int]) -> List[int]:
 @_register_five_ten_crop_kernel_internal(five_crop, torch.Tensor)
 @_register_five_ten_crop_kernel_internal(five_crop, datapoints.Image)
-def five_crop_image_tensor(
+def five_crop_image(
     image: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     crop_height, crop_width = _parse_five_crop_size(size)
@@ -2243,30 +2241,30 @@ def five_crop_image_tensor(
     if crop_width > image_width or crop_height > image_height:
         raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")
 
-    tl = crop_image_tensor(image, 0, 0, crop_height, crop_width)
-    tr = crop_image_tensor(image, 0, image_width - crop_width, crop_height, crop_width)
-    bl = crop_image_tensor(image, image_height - crop_height, 0, crop_height, crop_width)
-    br = crop_image_tensor(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
-    center = center_crop_image_tensor(image, [crop_height, crop_width])
+    tl = crop_image(image, 0, 0, crop_height, crop_width)
+    tr = crop_image(image, 0, image_width - crop_width, crop_height, crop_width)
+    bl = crop_image(image, image_height - crop_height, 0, crop_height, crop_width)
+    br = crop_image(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
+    center = center_crop_image(image, [crop_height, crop_width])
 
     return tl, tr, bl, br, center
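Because `five_crop` returns a tuple rather than a single tensor, a common pattern is to stack the crops into a batch for inference; a short sketch:

import torch
from torchvision.transforms.v2 import functional as F

img = torch.rand(3, 64, 64)
tl, tr, bl, br, center = F.five_crop(img, size=[32, 32])
batch = torch.stack([tl, tr, bl, br, center])  # shape (5, 3, 32, 32), ready for batched inference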
 @_register_five_ten_crop_kernel_internal(five_crop, PIL.Image.Image)
-def five_crop_image_pil(
+def _five_crop_image_pil(
     image: PIL.Image.Image, size: List[int]
 ) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
     crop_height, crop_width = _parse_five_crop_size(size)
-    image_height, image_width = get_size_image_pil(image)
+    image_height, image_width = _get_size_image_pil(image)
 
     if crop_width > image_width or crop_height > image_height:
         raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")
 
-    tl = crop_image_pil(image, 0, 0, crop_height, crop_width)
-    tr = crop_image_pil(image, 0, image_width - crop_width, crop_height, crop_width)
-    bl = crop_image_pil(image, image_height - crop_height, 0, crop_height, crop_width)
-    br = crop_image_pil(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
-    center = center_crop_image_pil(image, [crop_height, crop_width])
+    tl = _crop_image_pil(image, 0, 0, crop_height, crop_width)
+    tr = _crop_image_pil(image, 0, image_width - crop_width, crop_height, crop_width)
+    bl = _crop_image_pil(image, image_height - crop_height, 0, crop_height, crop_width)
+    br = _crop_image_pil(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
+    center = _center_crop_image_pil(image, [crop_height, crop_width])
 
     return tl, tr, bl, br, center
@@ -2275,7 +2273,7 @@ def five_crop_image_pil(
 def five_crop_video(
     video: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    return five_crop_image_tensor(video, size)
+    return five_crop_image(video, size)
 
 def ten_crop(
@@ -2293,7 +2291,7 @@ def ten_crop(
     torch.Tensor,
 ]:
     if torch.jit.is_scripting():
-        return ten_crop_image_tensor(inpt, size=size, vertical_flip=vertical_flip)
+        return ten_crop_image(inpt, size=size, vertical_flip=vertical_flip)
 
     _log_api_usage_once(ten_crop)
@@ -2303,7 +2301,7 @@ def ten_crop(
 @_register_five_ten_crop_kernel_internal(ten_crop, torch.Tensor)
 @_register_five_ten_crop_kernel_internal(ten_crop, datapoints.Image)
-def ten_crop_image_tensor(
+def ten_crop_image(
     image: torch.Tensor, size: List[int], vertical_flip: bool = False
 ) -> Tuple[
     torch.Tensor,
@@ -2317,20 +2315,20 @@ def ten_crop_image_tensor(
     torch.Tensor,
     torch.Tensor,
 ]:
-    non_flipped = five_crop_image_tensor(image, size)
+    non_flipped = five_crop_image(image, size)
 
     if vertical_flip:
-        image = vertical_flip_image_tensor(image)
+        image = vertical_flip_image(image)
     else:
-        image = horizontal_flip_image_tensor(image)
+        image = horizontal_flip_image(image)
 
-    flipped = five_crop_image_tensor(image, size)
+    flipped = five_crop_image(image, size)
 
     return non_flipped + flipped
 
 @_register_five_ten_crop_kernel_internal(ten_crop, PIL.Image.Image)
-def ten_crop_image_pil(
+def _ten_crop_image_pil(
     image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
 ) -> Tuple[
     PIL.Image.Image,
@@ -2344,14 +2342,14 @@ def ten_crop_image_pil(
     PIL.Image.Image,
     PIL.Image.Image,
 ]:
-    non_flipped = five_crop_image_pil(image, size)
+    non_flipped = _five_crop_image_pil(image, size)
 
     if vertical_flip:
-        image = vertical_flip_image_pil(image)
+        image = _vertical_flip_image_pil(image)
     else:
-        image = horizontal_flip_image_pil(image)
+        image = _horizontal_flip_image_pil(image)
 
-    flipped = five_crop_image_pil(image, size)
+    flipped = _five_crop_image_pil(image, size)
 
     return non_flipped + flipped
@@ -2371,4 +2369,4 @@ def ten_crop_video(
     torch.Tensor,
     torch.Tensor,
 ]:
-    return ten_crop_image_tensor(video, size, vertical_flip=vertical_flip)
+    return ten_crop_image(video, size, vertical_flip=vertical_flip)
@@ -13,7 +13,7 @@ from ._utils import _get_kernel, _register_kernel_internal, is_simple_tensor
 def get_dimensions(inpt: torch.Tensor) -> List[int]:
     if torch.jit.is_scripting():
-        return get_dimensions_image_tensor(inpt)
+        return get_dimensions_image(inpt)
 
     _log_api_usage_once(get_dimensions)
@@ -23,7 +23,7 @@ def get_dimensions(inpt: torch.Tensor) -> List[int]:
 @_register_kernel_internal(get_dimensions, torch.Tensor)
 @_register_kernel_internal(get_dimensions, datapoints.Image, datapoint_wrapper=False)
-def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
+def get_dimensions_image(image: torch.Tensor) -> List[int]:
     chw = list(image.shape[-3:])
     ndims = len(chw)
     if ndims == 3:
@@ -35,17 +35,17 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
         raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
 
-get_dimensions_image_pil = _register_kernel_internal(get_dimensions, PIL.Image.Image)(_FP.get_dimensions)
+_get_dimensions_image_pil = _register_kernel_internal(get_dimensions, PIL.Image.Image)(_FP.get_dimensions)
 
 @_register_kernel_internal(get_dimensions, datapoints.Video, datapoint_wrapper=False)
 def get_dimensions_video(video: torch.Tensor) -> List[int]:
-    return get_dimensions_image_tensor(video)
+    return get_dimensions_image(video)
 
 def get_num_channels(inpt: torch.Tensor) -> int:
     if torch.jit.is_scripting():
-        return get_num_channels_image_tensor(inpt)
+        return get_num_channels_image(inpt)
 
     _log_api_usage_once(get_num_channels)
@@ -55,7 +55,7 @@ def get_num_channels(inpt: torch.Tensor) -> int:
 @_register_kernel_internal(get_num_channels, torch.Tensor)
 @_register_kernel_internal(get_num_channels, datapoints.Image, datapoint_wrapper=False)
-def get_num_channels_image_tensor(image: torch.Tensor) -> int:
+def get_num_channels_image(image: torch.Tensor) -> int:
     chw = image.shape[-3:]
     ndims = len(chw)
     if ndims == 3:
@@ -66,12 +66,12 @@ def get_num_channels_image_tensor(image: torch.Tensor) -> int:
         raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
 
-get_num_channels_image_pil = _register_kernel_internal(get_num_channels, PIL.Image.Image)(_FP.get_image_num_channels)
+_get_num_channels_image_pil = _register_kernel_internal(get_num_channels, PIL.Image.Image)(_FP.get_image_num_channels)
 
 @_register_kernel_internal(get_num_channels, datapoints.Video, datapoint_wrapper=False)
 def get_num_channels_video(video: torch.Tensor) -> int:
-    return get_num_channels_image_tensor(video)
+    return get_num_channels_image(video)
 
 # We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
@@ -81,7 +81,7 @@ get_image_num_channels = get_num_channels
 def get_size(inpt: torch.Tensor) -> List[int]:
     if torch.jit.is_scripting():
-        return get_size_image_tensor(inpt)
+        return get_size_image(inpt)
 
     _log_api_usage_once(get_size)
@@ -91,7 +91,7 @@ def get_size(inpt: torch.Tensor) -> List[int]:
 @_register_kernel_internal(get_size, torch.Tensor)
 @_register_kernel_internal(get_size, datapoints.Image, datapoint_wrapper=False)
-def get_size_image_tensor(image: torch.Tensor) -> List[int]:
+def get_size_image(image: torch.Tensor) -> List[int]:
     hw = list(image.shape[-2:])
     ndims = len(hw)
     if ndims == 2:
@@ -101,19 +101,19 @@ def get_size_image_tensor(image: torch.Tensor) -> List[int]:
 @_register_kernel_internal(get_size, PIL.Image.Image)
-def get_size_image_pil(image: PIL.Image.Image) -> List[int]:
+def _get_size_image_pil(image: PIL.Image.Image) -> List[int]:
     width, height = _FP.get_image_size(image)
     return [height, width]
 
 @_register_kernel_internal(get_size, datapoints.Video, datapoint_wrapper=False)
 def get_size_video(video: torch.Tensor) -> List[int]:
-    return get_size_image_tensor(video)
+    return get_size_image(video)
 
 @_register_kernel_internal(get_size, datapoints.Mask, datapoint_wrapper=False)
 def get_size_mask(mask: torch.Tensor) -> List[int]:
-    return get_size_image_tensor(mask)
+    return get_size_image(mask)
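All of these `get_size_*` kernels return `[height, width]` regardless of input type, as the PIL kernel above makes explicit by swapping the width-first order of `_FP.get_image_size`. A sketch:

import torch
from torchvision.transforms.v2 import functional as F

img = torch.rand(3, 48, 64)
assert F.get_size(img) == [48, 64]  # [height, width], same convention for all input types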
 @_register_kernel_internal(get_size, datapoints.BoundingBoxes, datapoint_wrapper=False)
...
@@ -21,7 +21,7 @@ def normalize(
     inplace: bool = False,
 ) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
+        return normalize_image(inpt, mean=mean, std=std, inplace=inplace)
 
     _log_api_usage_once(normalize)
@@ -31,9 +31,7 @@ def normalize(
 @_register_kernel_internal(normalize, torch.Tensor)
 @_register_kernel_internal(normalize, datapoints.Image)
-def normalize_image_tensor(
-    image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False
-) -> torch.Tensor:
+def normalize_image(image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
     if not image.is_floating_point():
         raise TypeError(f"Input tensor should be a float tensor. Got {image.dtype}.")
@@ -68,12 +66,12 @@ def normalize_image_tensor(
 @_register_kernel_internal(normalize, datapoints.Video)
 def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
-    return normalize_image_tensor(video, mean, std, inplace=inplace)
+    return normalize_image(video, mean, std, inplace=inplace)
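Since `normalize_image` rejects integer inputs, uint8 images must be converted first; a sketch combining it with `to_dtype` from the same module:

import torch
from torchvision.transforms.v2 import functional as F

img_u8 = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
img_f32 = F.to_dtype(img_u8, dtype=torch.float32, scale=True)  # normalize raises TypeError on integer dtypes
out = F.normalize(img_f32, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])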
 def gaussian_blur(inpt: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
+        return gaussian_blur_image(inpt, kernel_size=kernel_size, sigma=sigma)
 
     _log_api_usage_once(gaussian_blur)
@@ -99,7 +97,7 @@ def _get_gaussian_kernel2d(
 @_register_kernel_internal(gaussian_blur, torch.Tensor)
 @_register_kernel_internal(gaussian_blur, datapoints.Image)
-def gaussian_blur_image_tensor(
+def gaussian_blur_image(
     image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> torch.Tensor:
     # TODO: consider deprecating integers from sigma on the future
@@ -164,11 +162,11 @@ def gaussian_blur_image_tensor(
 @_register_kernel_internal(gaussian_blur, PIL.Image.Image)
-def gaussian_blur_image_pil(
+def _gaussian_blur_image_pil(
     image: PIL.Image.Image, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> PIL.Image.Image:
     t_img = pil_to_tensor(image)
-    output = gaussian_blur_image_tensor(t_img, kernel_size=kernel_size, sigma=sigma)
+    output = gaussian_blur_image(t_img, kernel_size=kernel_size, sigma=sigma)
     return to_pil_image(output, mode=image.mode)
@@ -176,12 +174,12 @@ def gaussian_blur_image_pil(
 def gaussian_blur_video(
     video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> torch.Tensor:
-    return gaussian_blur_image_tensor(video, kernel_size, sigma)
+    return gaussian_blur_image(video, kernel_size, sigma)
 
 def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return to_dtype_image_tensor(inpt, dtype=dtype, scale=scale)
+        return to_dtype_image(inpt, dtype=dtype, scale=scale)
 
     _log_api_usage_once(to_dtype)
@@ -206,7 +204,7 @@ def _num_value_bits(dtype: torch.dtype) -> int:
 @_register_kernel_internal(to_dtype, torch.Tensor)
 @_register_kernel_internal(to_dtype, datapoints.Image)
-def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
+def to_dtype_image(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
     if image.dtype == dtype:
         return image
@@ -260,12 +258,12 @@ def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float,
 # We encourage users to use to_dtype() instead but we keep this for BC
 def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
-    return to_dtype_image_tensor(image, dtype=dtype, scale=True)
+    return to_dtype_image(image, dtype=dtype, scale=True)
 
 @_register_kernel_internal(to_dtype, datapoints.Video)
 def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
-    return to_dtype_image_tensor(video, dtype, scale=scale)
+    return to_dtype_image(video, dtype, scale=scale)
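As the backwards-compatibility comment above states, `convert_image_dtype` is now a thin wrapper over `to_dtype_image` with `scale=True`, so the two calls below should produce identical results:

import torch
from torchvision.transforms.v2 import functional as F

img = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
a = F.convert_image_dtype(img, torch.float32)
b = F.to_dtype(img, dtype=torch.float32, scale=True)
assert torch.equal(a, b)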
 @_register_kernel_internal(to_dtype, datapoints.BoundingBoxes, datapoint_wrapper=False)
...
@@ -8,7 +8,7 @@ from torchvision.transforms import functional as _F
 @torch.jit.unused
-def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoints.Image:
+def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoints.Image:
     if isinstance(inpt, np.ndarray):
         output = torch.from_numpy(inpt).permute((2, 0, 1)).contiguous()
     elif isinstance(inpt, PIL.Image.Image):
@@ -20,9 +20,5 @@ def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> d
     return datapoints.Image(output)
 
-to_image_pil = _F.to_pil_image
+to_pil_image = _F.to_pil_image
 pil_to_tensor = _F.pil_to_tensor
-
-# We changed the names to align them with the new naming scheme. Still, `to_pil_image` is
-# prevalent and well understood. Thus, we just alias it without deprecating the old name.
-to_pil_image = to_image_pil
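With `to_image_tensor` renamed to `to_image` and `to_pil_image` promoted from alias to the primary binding, a round-trip sketch looks like this:

import torch
from torchvision.transforms.v2 import functional as F

img = torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)
pil = F.to_pil_image(img)  # tensor -> PIL.Image, keeping the prevalent name
back = F.to_image(pil)     # PIL.Image -> datapoints.Image wrapping a uint8 tensor
assert torch.equal(back, img)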