Unverified Commit 24890d71 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Fix issues with `get_image_size()` (#6581)



* Fix bug on `get_image_size()` and move it to deprecated. Introduce generic named spatial/channel equivalents.

* Update tests and fix mypy issues.

* Remove the use of get_image_size from ElasticTransform.

* Fix linter

* Apply suggestions from code review.

* Update torchvision/prototype/transforms/functional/_deprecated.py
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>

* Further changes from code review.

* Fix linter
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 6b2e0a08
...@@ -531,6 +531,8 @@ def erase_image_tensor(): ...@@ -531,6 +531,8 @@ def erase_image_tensor():
and name and name
not in { not in {
"to_image_tensor", "to_image_tensor",
"get_num_channels",
"get_spatial_size",
"get_image_num_channels", "get_image_num_channels",
"get_image_size", "get_image_size",
} }
......
...@@ -9,7 +9,8 @@ from ._meta import ( ...@@ -9,7 +9,8 @@ from ._meta import (
convert_color_space, convert_color_space,
get_dimensions, get_dimensions,
get_image_num_channels, get_image_num_channels,
get_image_size, get_num_channels,
get_spatial_size,
) # usort: skip ) # usort: skip
from ._augment import erase, erase_image_pil, erase_image_tensor from ._augment import erase, erase_image_pil, erase_image_tensor
...@@ -125,4 +126,4 @@ from ._type_conversion import ( ...@@ -125,4 +126,4 @@ from ._type_conversion import (
to_pil_image, to_pil_image,
) )
from ._deprecated import rgb_to_grayscale, to_grayscale # usort: skip from ._deprecated import get_image_size, rgb_to_grayscale, to_grayscale, to_tensor # usort: skip
import warnings import warnings
from typing import Any, Union from typing import Any, List, Union
import PIL.Image import PIL.Image
import torch import torch
...@@ -50,3 +50,11 @@ def to_tensor(inpt: Any) -> torch.Tensor: ...@@ -50,3 +50,11 @@ def to_tensor(inpt: Any) -> torch.Tensor:
"Instead, please use `to_image_tensor(...)` followed by `convert_image_dtype(...)`." "Instead, please use `to_image_tensor(...)` followed by `convert_image_dtype(...)`."
) )
return _F.to_tensor(inpt) return _F.to_tensor(inpt)
def get_image_size(inpt: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> List[int]:
warnings.warn(
"The function `get_image_size(...)` is deprecated and will be removed in a future release. "
"Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`."
)
return _F.get_image_size(inpt)
...@@ -34,14 +34,19 @@ def get_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) ...@@ -34,14 +34,19 @@ def get_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image])
return list(get_chw(image)) return list(get_chw(image))
def get_image_num_channels(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> int: def get_num_channels(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> int:
num_channels, *_ = get_chw(image) num_channels, *_ = get_chw(image)
return num_channels return num_channels
def get_image_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> List[int]: # We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
_, *image_size = get_chw(image) # deprecating the old names.
return image_size get_image_num_channels = get_num_channels
def get_spatial_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> List[int]:
_, *size = get_chw(image)
return size
def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor: def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
......
...@@ -2162,8 +2162,8 @@ class ElasticTransform(torch.nn.Module): ...@@ -2162,8 +2162,8 @@ class ElasticTransform(torch.nn.Module):
Returns: Returns:
PIL Image or Tensor: Transformed image. PIL Image or Tensor: Transformed image.
""" """
size = F.get_image_size(tensor)[::-1] _, height, width = F.get_dimensions(tensor)
displacement = self.get_params(self.alpha, self.sigma, size) displacement = self.get_params(self.alpha, self.sigma, [height, width])
return F.elastic_transform(tensor, displacement, self.interpolation, self.fill) return F.elastic_transform(tensor, displacement, self.interpolation, self.fill)
def __repr__(self): def __repr__(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment