Unverified Commit 0d7807d5 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

[prototype] Cleaning up the size dimension methods (#6828)

* Cleaning up the size dimension methods.

* Change error messages.
parent 7278abec
......@@ -7,7 +7,18 @@ from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
get_dimensions_image_tensor = _FT.get_dimensions
def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
chw = list(image.shape[-3:])
ndims = len(chw)
if ndims == 3:
return chw
elif ndims == 2:
chw.insert(0, 1)
return chw
else:
raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
get_dimensions_image_pil = _FP.get_dimensions
......@@ -24,7 +35,17 @@ def get_dimensions(image: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -
return get_dimensions_image_pil(image)
get_num_channels_image_tensor = _FT.get_image_num_channels
def get_num_channels_image_tensor(image: torch.Tensor) -> int:
chw = image.shape[-3:]
ndims = len(chw)
if ndims == 3:
return chw[0]
elif ndims == 2:
return 1
else:
raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
get_num_channels_image_pil = _FP.get_image_num_channels
......@@ -36,11 +57,11 @@ def get_num_channels(image: Union[features.ImageTypeJIT, features.VideoTypeJIT])
if isinstance(image, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
):
return _FT.get_image_num_channels(image)
return get_num_channels_image_tensor(image)
elif isinstance(image, (features.Image, features.Video)):
return image.num_channels
else:
return _FP.get_image_num_channels(image)
return get_num_channels_image_pil(image)
# We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
......@@ -49,8 +70,12 @@ get_image_num_channels = get_num_channels
def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]:
width, height = _FT.get_image_size(image)
return [height, width]
hw = list(image.shape[-2:])
ndims = len(hw)
if ndims == 2:
return hw
else:
raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
@torch.jit.unused
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment