Unverified Commit ca012d39 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

make PIL kernels private (#7831)

parent cdbbd666
......@@ -10,7 +10,7 @@ from torchvision.transforms import functional as _F
def to_tensor(inpt: Any) -> torch.Tensor:
    """Convert the input to a tensor via the legacy v1 kernel.

    Deprecated: callers should migrate to ``to_image(...)`` followed by
    ``to_dtype(..., dtype=torch.float32, scale=True)``.

    Args:
        inpt: PIL image, ndarray, or tensor accepted by the v1 ``to_tensor``.

    Returns:
        torch.Tensor: the converted tensor, as produced by ``_F.to_tensor``.
    """
    # Single up-to-date message (the diff residue duplicated the old wording).
    warnings.warn(
        "The function `to_tensor(...)` is deprecated and will be removed in a future release. "
        "Instead, please use `to_image(...)` followed by `to_dtype(..., dtype=torch.float32, scale=True)`."
    )
    return _F.to_tensor(inpt)
......
......@@ -13,7 +13,7 @@ from ._utils import _get_kernel, _register_kernel_internal, is_simple_tensor
def get_dimensions(inpt: torch.Tensor) -> List[int]:
if torch.jit.is_scripting():
return get_dimensions_image_tensor(inpt)
return get_dimensions_image(inpt)
_log_api_usage_once(get_dimensions)
......@@ -23,7 +23,7 @@ def get_dimensions(inpt: torch.Tensor) -> List[int]:
@_register_kernel_internal(get_dimensions, torch.Tensor)
@_register_kernel_internal(get_dimensions, datapoints.Image, datapoint_wrapper=False)
def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
def get_dimensions_image(image: torch.Tensor) -> List[int]:
chw = list(image.shape[-3:])
ndims = len(chw)
if ndims == 3:
......@@ -35,17 +35,17 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
# Register the PIL kernel for get_dimensions under a private name; the public
# surface only exposes the dispatcher (PIL kernels were made private in #7831).
_get_dimensions_image_pil = _register_kernel_internal(get_dimensions, PIL.Image.Image)(_FP.get_dimensions)
@_register_kernel_internal(get_dimensions, datapoints.Video, datapoint_wrapper=False)
def get_dimensions_video(video: torch.Tensor) -> List[int]:
    """Return the dimensions of a video tensor by delegating to the image kernel."""
    # Videos share the trailing-dim layout with images, so the image kernel applies.
    return get_dimensions_image(video)
def get_num_channels(inpt: torch.Tensor) -> int:
if torch.jit.is_scripting():
return get_num_channels_image_tensor(inpt)
return get_num_channels_image(inpt)
_log_api_usage_once(get_num_channels)
......@@ -55,7 +55,7 @@ def get_num_channels(inpt: torch.Tensor) -> int:
@_register_kernel_internal(get_num_channels, torch.Tensor)
@_register_kernel_internal(get_num_channels, datapoints.Image, datapoint_wrapper=False)
def get_num_channels_image_tensor(image: torch.Tensor) -> int:
def get_num_channels_image(image: torch.Tensor) -> int:
chw = image.shape[-3:]
ndims = len(chw)
if ndims == 3:
......@@ -66,12 +66,12 @@ def get_num_channels_image_tensor(image: torch.Tensor) -> int:
raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
# Register the PIL kernel for get_num_channels under a private name; the public
# surface only exposes the dispatcher (PIL kernels were made private in #7831).
_get_num_channels_image_pil = _register_kernel_internal(get_num_channels, PIL.Image.Image)(_FP.get_image_num_channels)
@_register_kernel_internal(get_num_channels, datapoints.Video, datapoint_wrapper=False)
def get_num_channels_video(video: torch.Tensor) -> int:
    """Return the number of channels of a video tensor via the image kernel."""
    # Videos share the trailing-dim layout with images, so the image kernel applies.
    return get_num_channels_image(video)
# We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
......@@ -81,7 +81,7 @@ get_image_num_channels = get_num_channels
def get_size(inpt: torch.Tensor) -> List[int]:
if torch.jit.is_scripting():
return get_size_image_tensor(inpt)
return get_size_image(inpt)
_log_api_usage_once(get_size)
......@@ -91,7 +91,7 @@ def get_size(inpt: torch.Tensor) -> List[int]:
@_register_kernel_internal(get_size, torch.Tensor)
@_register_kernel_internal(get_size, datapoints.Image, datapoint_wrapper=False)
def get_size_image_tensor(image: torch.Tensor) -> List[int]:
def get_size_image(image: torch.Tensor) -> List[int]:
hw = list(image.shape[-2:])
ndims = len(hw)
if ndims == 2:
......@@ -101,19 +101,19 @@ def get_size_image_tensor(image: torch.Tensor) -> List[int]:
@_register_kernel_internal(get_size, PIL.Image.Image)
def _get_size_image_pil(image: PIL.Image.Image) -> List[int]:
    """Return the spatial size ``[height, width]`` of a PIL image."""
    # _FP.get_image_size reports (width, height); flip to the [H, W] convention
    # used throughout these kernels.
    width, height = _FP.get_image_size(image)
    return [height, width]
@_register_kernel_internal(get_size, datapoints.Video, datapoint_wrapper=False)
def get_size_video(video: torch.Tensor) -> List[int]:
    """Return the spatial size ``[height, width]`` of a video tensor."""
    # Delegates to the image kernel, which reads the trailing two dims.
    return get_size_image(video)
@_register_kernel_internal(get_size, datapoints.Mask, datapoint_wrapper=False)
def get_size_mask(mask: torch.Tensor) -> List[int]:
    """Return the spatial size ``[height, width]`` of a mask tensor."""
    # Masks share the trailing [H, W] layout with images.
    return get_size_image(mask)
@_register_kernel_internal(get_size, datapoints.BoundingBoxes, datapoint_wrapper=False)
......
......@@ -21,7 +21,7 @@ def normalize(
inplace: bool = False,
) -> torch.Tensor:
if torch.jit.is_scripting():
return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
return normalize_image(inpt, mean=mean, std=std, inplace=inplace)
_log_api_usage_once(normalize)
......@@ -31,9 +31,7 @@ def normalize(
@_register_kernel_internal(normalize, torch.Tensor)
@_register_kernel_internal(normalize, datapoints.Image)
def normalize_image_tensor(
image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False
) -> torch.Tensor:
def normalize_image(image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
if not image.is_floating_point():
raise TypeError(f"Input tensor should be a float tensor. Got {image.dtype}.")
......@@ -68,12 +66,12 @@ def normalize_image_tensor(
@_register_kernel_internal(normalize, datapoints.Video)
def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
    """Normalize a float video tensor with per-channel ``mean`` and ``std``.

    Delegates to the image kernel; ``inplace=True`` mutates ``video`` directly.
    """
    return normalize_image(video, mean, std, inplace=inplace)
def gaussian_blur(inpt: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> torch.Tensor:
if torch.jit.is_scripting():
return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
return gaussian_blur_image(inpt, kernel_size=kernel_size, sigma=sigma)
_log_api_usage_once(gaussian_blur)
......@@ -99,7 +97,7 @@ def _get_gaussian_kernel2d(
@_register_kernel_internal(gaussian_blur, torch.Tensor)
@_register_kernel_internal(gaussian_blur, datapoints.Image)
def gaussian_blur_image_tensor(
def gaussian_blur_image(
image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor:
# TODO: consider deprecating integers from sigma on the future
......@@ -164,11 +162,11 @@ def gaussian_blur_image_tensor(
@_register_kernel_internal(gaussian_blur, PIL.Image.Image)
def _gaussian_blur_image_pil(
    image: PIL.Image.Image, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> PIL.Image.Image:
    """Gaussian-blur a PIL image by round-tripping through the tensor kernel.

    Args:
        image: input PIL image.
        kernel_size: blur kernel size, forwarded to the tensor kernel.
        sigma: optional standard deviation(s), forwarded to the tensor kernel.

    Returns:
        PIL.Image.Image: blurred image converted back with the original mode.
    """
    # The blur itself is only implemented for tensors, so convert, blur, convert back.
    t_img = pil_to_tensor(image)
    output = gaussian_blur_image(t_img, kernel_size=kernel_size, sigma=sigma)
    return to_pil_image(output, mode=image.mode)
......@@ -176,12 +174,12 @@ def gaussian_blur_image_pil(
def gaussian_blur_video(
    video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor:
    """Gaussian-blur a video tensor by delegating to the image kernel."""
    return gaussian_blur_image(video, kernel_size, sigma)
def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
if torch.jit.is_scripting():
return to_dtype_image_tensor(inpt, dtype=dtype, scale=scale)
return to_dtype_image(inpt, dtype=dtype, scale=scale)
_log_api_usage_once(to_dtype)
......@@ -206,7 +204,7 @@ def _num_value_bits(dtype: torch.dtype) -> int:
@_register_kernel_internal(to_dtype, torch.Tensor)
@_register_kernel_internal(to_dtype, datapoints.Image)
def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
def to_dtype_image(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
if image.dtype == dtype:
return image
......@@ -260,12 +258,12 @@ def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float,
# We encourage users to use to_dtype() instead but we keep this for BC
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Backward-compatible wrapper: convert ``image`` to ``dtype`` with value scaling.

    Equivalent to ``to_dtype_image(image, dtype=dtype, scale=True)``.
    """
    return to_dtype_image(image, dtype=dtype, scale=True)
@_register_kernel_internal(to_dtype, datapoints.Video)
def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
    """Convert a video tensor to ``dtype`` via the image kernel.

    When ``scale`` is True, values are rescaled to the new dtype's range.
    """
    return to_dtype_image(video, dtype, scale=scale)
@_register_kernel_internal(to_dtype, datapoints.BoundingBoxes, datapoint_wrapper=False)
......
......@@ -8,7 +8,7 @@ from torchvision.transforms import functional as _F
@torch.jit.unused
def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoints.Image:
def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoints.Image:
if isinstance(inpt, np.ndarray):
output = torch.from_numpy(inpt).permute((2, 0, 1)).contiguous()
elif isinstance(inpt, PIL.Image.Image):
......@@ -20,9 +20,5 @@ def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> d
return datapoints.Image(output)
# `to_pil_image` is prevalent and well understood, so we re-export the v1
# functionals directly instead of keeping the intermediate `to_image_pil` alias.
to_pil_image = _F.to_pil_image
pil_to_tensor = _F.pil_to_tensor
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment