Unverified Commit 0d7807d5 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

[prototype] Cleaning up the size dimension methods (#6828)

* Cleaning up the size dimension methods.

* Change error messages.
parent 7278abec
...@@ -7,7 +7,18 @@ from torchvision.prototype.features import BoundingBoxFormat, ColorSpace ...@@ -7,7 +7,18 @@ from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
get_dimensions_image_tensor = _FT.get_dimensions def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
chw = list(image.shape[-3:])
ndims = len(chw)
if ndims == 3:
return chw
elif ndims == 2:
chw.insert(0, 1)
return chw
else:
raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
get_dimensions_image_pil = _FP.get_dimensions get_dimensions_image_pil = _FP.get_dimensions
...@@ -24,7 +35,17 @@ def get_dimensions(image: Union[features.ImageTypeJIT, features.VideoTypeJIT]) - ...@@ -24,7 +35,17 @@ def get_dimensions(image: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -
return get_dimensions_image_pil(image) return get_dimensions_image_pil(image)
get_num_channels_image_tensor = _FT.get_image_num_channels def get_num_channels_image_tensor(image: torch.Tensor) -> int:
chw = image.shape[-3:]
ndims = len(chw)
if ndims == 3:
return chw[0]
elif ndims == 2:
return 1
else:
raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
get_num_channels_image_pil = _FP.get_image_num_channels get_num_channels_image_pil = _FP.get_image_num_channels
...@@ -36,11 +57,11 @@ def get_num_channels(image: Union[features.ImageTypeJIT, features.VideoTypeJIT]) ...@@ -36,11 +57,11 @@ def get_num_channels(image: Union[features.ImageTypeJIT, features.VideoTypeJIT])
if isinstance(image, torch.Tensor) and ( if isinstance(image, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video)) torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
): ):
return _FT.get_image_num_channels(image) return get_num_channels_image_tensor(image)
elif isinstance(image, (features.Image, features.Video)): elif isinstance(image, (features.Image, features.Video)):
return image.num_channels return image.num_channels
else: else:
return _FP.get_image_num_channels(image) return get_num_channels_image_pil(image)
# We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without # We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
...@@ -49,8 +70,12 @@ get_image_num_channels = get_num_channels ...@@ -49,8 +70,12 @@ get_image_num_channels = get_num_channels
def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]: def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]:
width, height = _FT.get_image_size(image) hw = list(image.shape[-2:])
return [height, width] ndims = len(hw)
if ndims == 2:
return hw
else:
raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
@torch.jit.unused @torch.jit.unused
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment