Unverified Commit 6e203b44 authored by Vasilis Vryniotis, committed by GitHub

[prototype] Rewrite the meta dimension methods (#6722)

* Rewrite `get_dimensions`, `get_num_channels` and `get_spatial_size`

* Remove `get_chw`

* Remove comments

* Make `get_spatial_size` support non-image input

* Reduce the unnecessary use of `get_dimensions*`

* Fix linters

* Fix merge bug

* Linter

* Fix linter
parent 4c049ca3
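The diff below removes the internal `get_chw` helper and rewrites the public meta dispatchers. As a minimal usage sketch of the reworked API (assuming the prototype import path shown in this diff; exact paths may change between nightlies):

```python
import torch
from torchvision.prototype.transforms.functional import (
    get_dimensions,
    get_num_channels,
    get_spatial_size,
)

img = torch.rand(3, 32, 64)   # plain CHW tensor

print(get_dimensions(img))    # [3, 32, 64] -> [channels, height, width]
print(get_num_channels(img))  # 3
print(get_spatial_size(img))  # [32, 64]    -> [height, width]
```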
from __future__ import annotations
from typing import Any, List, Optional, Union
from typing import Any, cast, List, Optional, Tuple, Union
import torch
from torchvision.transforms import InterpolationMode
@@ -32,6 +32,10 @@ class Mask(_Feature):
) -> Mask:
return cls._wrap(tensor)
@property
def image_size(self) -> Tuple[int, int]:
return cast(Tuple[int, int], tuple(self.shape[-2:]))
def horizontal_flip(self) -> Mask:
output = self._F.horizontal_flip_mask(self)
return Mask.wrap_like(self, output)
......
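The new `image_size` property added above is what later lets `get_spatial_size` handle `Mask` and other non-image features. A small sketch, assuming a `Mask` can wrap a plain tensor directly as in the prototype feature classes of this era:

```python
import torch
from torchvision.prototype import features

# Constructor signature is an assumption based on the prototype _Feature classes.
mask = features.Mask(torch.zeros(1, 32, 64, dtype=torch.uint8))
print(mask.image_size)  # (32, 64), i.e. (height, width) taken from shape[-2:]
```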
@@ -7,7 +7,7 @@ import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype import features
from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._meta import get_chw
from torchvision.prototype.transforms.functional._meta import get_spatial_size
from ._utils import _isinstance, _setup_fill_arg
@@ -278,7 +278,7 @@ class AutoAugment(_AutoAugmentBase):
sample = inputs if len(inputs) > 1 else inputs[0]
id, image_or_video = self._extract_image_or_video(sample)
_, height, width = get_chw(image_or_video)
height, width = get_spatial_size(image_or_video)
policy = self._policies[int(torch.randint(len(self._policies), ()))]
@@ -349,7 +349,7 @@ class RandAugment(_AutoAugmentBase):
sample = inputs if len(inputs) > 1 else inputs[0]
id, image_or_video = self._extract_image_or_video(sample)
_, height, width = get_chw(image_or_video)
height, width = get_spatial_size(image_or_video)
for _ in range(self.num_ops):
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
@@ -403,7 +403,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
sample = inputs if len(inputs) > 1 else inputs[0]
id, image_or_video = self._extract_image_or_video(sample)
_, height, width = get_chw(image_or_video)
height, width = get_spatial_size(image_or_video)
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
@@ -473,7 +473,7 @@ class AugMix(_AutoAugmentBase):
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
id, orig_image_or_video = self._extract_image_or_video(sample)
_, height, width = get_chw(orig_image_or_video)
height, width = get_spatial_size(orig_image_or_video)
if isinstance(orig_image_or_video, torch.Tensor):
image_or_video = orig_image_or_video
......
@@ -10,7 +10,7 @@ from torchvision._utils import sequence_to_str
from torchvision.prototype import features
from torchvision.prototype.features._feature import FillType
from torchvision.prototype.transforms.functional._meta import get_chw
from torchvision.prototype.transforms.functional._meta import get_dimensions
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
from typing_extensions import Literal
@@ -80,7 +80,7 @@ def query_bounding_box(sample: Any) -> features.BoundingBox:
def query_chw(sample: Any) -> Tuple[int, int, int]:
flat_sample, _ = tree_flatten(sample)
chws = {
get_chw(item)
tuple(get_dimensions(item))
for item in flat_sample
if isinstance(item, (features.Image, PIL.Image.Image, features.Video)) or features.is_simple_tensor(item)
}
@@ -88,7 +88,8 @@ def query_chw(sample: Any) -> Tuple[int, int, int]:
raise TypeError("No image or video was found in the sample")
elif len(chws) > 1:
raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
return chws.pop()
c, h, w = chws.pop()
return c, h, w
def _isinstance(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
......
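The extra unpacking in `query_chw` above is presumably needed because `tuple(get_dimensions(item))` is typed as `Tuple[int, ...]`, so returning the popped element directly no longer matches the declared `Tuple[int, int, int]` return type; unpacking into `c, h, w` asserts the length at runtime and narrows the type. An illustrative, hypothetical helper showing the same pattern:

```python
from typing import List, Tuple

def _narrow_to_chw(dims: List[int]) -> Tuple[int, int, int]:
    # Raises ValueError at runtime if len(dims) != 3, and satisfies the
    # 3-tuple return annotation that a bare `tuple(dims)` would not.
    c, h, w = dims
    return c, h, w

print(_narrow_to_chw([3, 224, 224]))  # (3, 224, 224)
```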
@@ -8,9 +8,15 @@ from ._meta import (
convert_color_space_image_pil,
convert_color_space_video,
convert_color_space,
get_dimensions_image_tensor,
get_dimensions_image_pil,
get_dimensions,
get_image_num_channels,
get_num_channels_image_tensor,
get_num_channels_image_pil,
get_num_channels,
get_spatial_size_image_tensor,
get_spatial_size_image_pil,
get_spatial_size,
) # usort: skip
......
@@ -21,7 +21,12 @@ from torchvision.transforms.functional_tensor import (
interpolate,
)
from ._meta import convert_format_bounding_box, get_dimensions_image_pil, get_dimensions_image_tensor
from ._meta import (
convert_format_bounding_box,
get_dimensions_image_tensor,
get_spatial_size_image_pil,
get_spatial_size_image_tensor,
)
horizontal_flip_image_tensor = _FT.hflip
horizontal_flip_image_pil = _FP.hflip
@@ -323,7 +328,7 @@ def affine_image_pil(
# it is visually better to estimate the center without 0.5 offset
# otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
if center is None:
_, height, width = get_dimensions_image_pil(image)
height, width = get_spatial_size_image_pil(image)
center = [width * 0.5, height * 0.5]
matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
@@ -1189,13 +1194,13 @@ def _center_crop_compute_crop_anchor(
def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
_, image_height, image_width = get_dimensions_image_tensor(image)
image_height, image_width = get_spatial_size_image_tensor(image)
if crop_height > image_height or crop_width > image_width:
padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
image = pad_image_tensor(image, padding_ltrb, fill=0)
_, image_height, image_width = get_dimensions_image_tensor(image)
image_height, image_width = get_spatial_size_image_tensor(image)
if crop_width == image_width and crop_height == image_height:
return image
@@ -1206,13 +1211,13 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
@torch.jit.unused
def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
_, image_height, image_width = get_dimensions_image_pil(image)
image_height, image_width = get_spatial_size_image_pil(image)
if crop_height > image_height or crop_width > image_width:
padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
image = pad_image_pil(image, padding_ltrb, fill=0)
_, image_height, image_width = get_dimensions_image_pil(image)
image_height, image_width = get_spatial_size_image_pil(image)
if crop_width == image_width and crop_height == image_height:
return image
@@ -1365,7 +1370,7 @@ def five_crop_image_tensor(
image: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
crop_height, crop_width = _parse_five_crop_size(size)
_, image_height, image_width = get_dimensions_image_tensor(image)
image_height, image_width = get_spatial_size_image_tensor(image)
if crop_width > image_width or crop_height > image_height:
msg = "Requested crop size {} is bigger than input size {}"
@@ -1385,7 +1390,7 @@ def five_crop_image_pil(
image: PIL.Image.Image, size: List[int]
) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
crop_height, crop_width = _parse_five_crop_size(size)
_, image_height, image_width = get_dimensions_image_pil(image)
image_height, image_width = get_spatial_size_image_pil(image)
if crop_width > image_width or crop_height > image_height:
msg = "Requested crop size {} is bigger than input size {}"
......
@@ -6,38 +6,37 @@ from torchvision.prototype import features
from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
get_dimensions_image_tensor = _FT.get_dimensions
get_dimensions_image_pil = _FP.get_dimensions
# TODO: Should this be prefixed with `_` similar to other methods that don't get exposed by init?
def get_chw(image: features.ImageOrVideoTypeJIT) -> Tuple[int, int, int]:
def get_dimensions(image: features.ImageOrVideoTypeJIT) -> List[int]:
if isinstance(image, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
):
channels, height, width = get_dimensions_image_tensor(image)
return get_dimensions_image_tensor(image)
elif isinstance(image, (features.Image, features.Video)):
channels = image.num_channels
height, width = image.image_size
else: # isinstance(image, PIL.Image.Image)
channels, height, width = get_dimensions_image_pil(image)
return channels, height, width
# The three functions below are here for BC. Whether we want to have two different kernels and how they and the
# compound version should be named is still under discussion: https://github.com/pytorch/vision/issues/6491
# Given that these kernels should also support boxes, masks, and videos, it is unlikely that their names will stay.
# They will either be deprecated or simply aliased to the new kernels if we have reached consensus about the issue
# detailed above.
return [channels, height, width]
else:
return get_dimensions_image_pil(image)
def get_dimensions(image: features.ImageOrVideoTypeJIT) -> List[int]:
return list(get_chw(image))
get_num_channels_image_tensor = _FT.get_image_num_channels
get_num_channels_image_pil = _FP.get_image_num_channels
def get_num_channels(image: features.ImageOrVideoTypeJIT) -> int:
num_channels, *_ = get_chw(image)
return num_channels
if isinstance(image, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
):
return _FT.get_image_num_channels(image)
elif isinstance(image, (features.Image, features.Video)):
return image.num_channels
else:
return _FP.get_image_num_channels(image)
# We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
@@ -45,9 +44,28 @@ def get_num_channels(image: features.ImageOrVideoTypeJIT) -> int:
get_image_num_channels = get_num_channels
def get_spatial_size(image: features.ImageOrVideoTypeJIT) -> List[int]:
_, *size = get_chw(image)
return size
def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]:
width, height = _FT.get_image_size(image)
return [height, width]
@torch.jit.unused
def get_spatial_size_image_pil(image: PIL.Image.Image) -> List[int]:
width, height = _FP.get_image_size(image)
return [height, width]
def get_spatial_size(inpt: features.InputTypeJIT) -> List[int]:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return get_spatial_size_image_tensor(inpt)
elif isinstance(inpt, features._Feature):
image_size = getattr(inpt, "image_size", None)
if image_size is not None:
return list(image_size)
else:
raise ValueError(f"Type {inpt.__class__} doesn't have spatial size.")
else:
return get_spatial_size_image_pil(inpt)
def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
......
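The new `get_spatial_size` dispatcher above falls back to a feature's `image_size` attribute, so it also covers non-image features such as bounding boxes and masks, and raises for features that carry no spatial size. A sketch of that path; the `BoundingBox` and `Label` constructor arguments are assumptions, not taken from this diff:

```python
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms.functional import get_spatial_size

# Assumed constructor signature for the prototype BoundingBox feature.
bbox = features.BoundingBox(
    [[0, 0, 10, 10]],
    format=features.BoundingBoxFormat.XYXY,
    image_size=(32, 64),
)
print(get_spatial_size(bbox))  # [32, 64], read from bbox.image_size

label = features.Label([1])    # a feature without an image_size attribute
get_spatial_size(label)        # raises ValueError: ... doesn't have spatial size.
```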