Unverified Commit 6e203b44 authored by Vasilis Vryniotis, committed by GitHub

[prototype] Rewrite the meta dimension methods (#6722)

* Rewrite `get_dimensions`, `get_num_channels` and `get_spatial_size`

* Remove `get_chw`

* Remove comments

* Make `get_spatial_size` support non-image input

* Reduce the unnecessary use of `get_dimensions*`

* Fix linters

* Fix merge bug

* Linter

* Fix linter
parent 4c049ca3
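As a quick orientation, here is a minimal usage sketch (not part of the commit) of the rewritten dispatchers called on a plain CHW tensor. The function names come from the hunks below; the shape is arbitrary.

```python
# Sketch only: exercising the rewritten meta dispatchers on a simple tensor image.
import torch
from torchvision.prototype.transforms import functional as F

image = torch.rand(3, 32, 64)  # CHW tensor treated as a simple image

print(F.get_dimensions(image))    # [3, 32, 64]; the old get_chw helper is removed
print(F.get_num_channels(image))  # 3
print(F.get_spatial_size(image))  # [32, 64], i.e. [height, width]
```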
 from __future__ import annotations
-from typing import Any, List, Optional, Union
+from typing import Any, cast, List, Optional, Tuple, Union
 import torch
 from torchvision.transforms import InterpolationMode
@@ -32,6 +32,10 @@ class Mask(_Feature):
     ) -> Mask:
         return cls._wrap(tensor)

+    @property
+    def image_size(self) -> Tuple[int, int]:
+        return cast(Tuple[int, int], tuple(self.shape[-2:]))
+
     def horizontal_flip(self) -> Mask:
         output = self._F.horizontal_flip_mask(self)
         return Mask.wrap_like(self, output)
...
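A small sketch (not from the commit) of what the new `image_size` property added above exposes; constructing `features.Mask` directly from a tensor is assumed to work as elsewhere in the prototype API.

```python
# Sketch: Mask now reports its spatial size like Image and Video.
import torch
from torchvision.prototype import features

mask = features.Mask(torch.zeros(1, 480, 640, dtype=torch.uint8))
print(mask.image_size)  # (480, 640), taken from the last two dims of .shape
```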
@@ -7,7 +7,7 @@ import torch
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision.prototype import features
 from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
-from torchvision.prototype.transforms.functional._meta import get_chw
+from torchvision.prototype.transforms.functional._meta import get_spatial_size
 from ._utils import _isinstance, _setup_fill_arg
@@ -278,7 +278,7 @@ class AutoAugment(_AutoAugmentBase):
         sample = inputs if len(inputs) > 1 else inputs[0]
         id, image_or_video = self._extract_image_or_video(sample)
-        _, height, width = get_chw(image_or_video)
+        height, width = get_spatial_size(image_or_video)
         policy = self._policies[int(torch.randint(len(self._policies), ()))]
@@ -349,7 +349,7 @@ class RandAugment(_AutoAugmentBase):
         sample = inputs if len(inputs) > 1 else inputs[0]
         id, image_or_video = self._extract_image_or_video(sample)
-        _, height, width = get_chw(image_or_video)
+        height, width = get_spatial_size(image_or_video)
         for _ in range(self.num_ops):
             transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
@@ -403,7 +403,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
         sample = inputs if len(inputs) > 1 else inputs[0]
         id, image_or_video = self._extract_image_or_video(sample)
-        _, height, width = get_chw(image_or_video)
+        height, width = get_spatial_size(image_or_video)
         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
@@ -473,7 +473,7 @@ class AugMix(_AutoAugmentBase):
     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
         id, orig_image_or_video = self._extract_image_or_video(sample)
-        _, height, width = get_chw(orig_image_or_video)
+        height, width = get_spatial_size(orig_image_or_video)
         if isinstance(orig_image_or_video, torch.Tensor):
             image_or_video = orig_image_or_video
...
@@ -10,7 +10,7 @@ from torchvision._utils import sequence_to_str
 from torchvision.prototype import features
 from torchvision.prototype.features._feature import FillType
-from torchvision.prototype.transforms.functional._meta import get_chw
+from torchvision.prototype.transforms.functional._meta import get_dimensions
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401
 from typing_extensions import Literal
@@ -80,7 +80,7 @@ def query_bounding_box(sample: Any) -> features.BoundingBox:
 def query_chw(sample: Any) -> Tuple[int, int, int]:
     flat_sample, _ = tree_flatten(sample)
     chws = {
-        get_chw(item)
+        tuple(get_dimensions(item))
         for item in flat_sample
         if isinstance(item, (features.Image, PIL.Image.Image, features.Video)) or features.is_simple_tensor(item)
     }
@@ -88,7 +88,8 @@ def query_chw(sample: Any) -> Tuple[int, int, int]:
         raise TypeError("No image or video was found in the sample")
     elif len(chws) > 1:
         raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
-    return chws.pop()
+    c, h, w = chws.pop()
+    return c, h, w
 def _isinstance(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
...
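The switch from `get_chw(item)` to `tuple(get_dimensions(item))` in `query_chw` above is not cosmetic: `get_dimensions` now returns a `List[int]`, which cannot be stored in a set, and the popped tuple is unpacked so the function keeps its `Tuple[int, int, int]` return type. A standalone sketch of the hashability point (values are made up):

```python
# Sketch: lists are unhashable, so the dimensions must be converted to tuples
# before deduplicating them in a set.
dims_a = [3, 32, 64]
dims_b = [3, 32, 64]

try:
    {dims_a}
except TypeError as exc:
    print(exc)  # unhashable type: 'list'

chws = {tuple(dims_a), tuple(dims_b)}  # duplicates collapse as intended
c, h, w = chws.pop()                   # unpack so callers still get a Tuple[int, int, int]
print(c, h, w)                         # 3 32 64
```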
@@ -8,9 +8,15 @@ from ._meta import (
     convert_color_space_image_pil,
     convert_color_space_video,
     convert_color_space,
+    get_dimensions_image_tensor,
+    get_dimensions_image_pil,
     get_dimensions,
     get_image_num_channels,
+    get_num_channels_image_tensor,
+    get_num_channels_image_pil,
     get_num_channels,
+    get_spatial_size_image_tensor,
+    get_spatial_size_image_pil,
     get_spatial_size,
 )  # usort: skip
...
@@ -21,7 +21,12 @@ from torchvision.transforms.functional_tensor import (
     interpolate,
 )
-from ._meta import convert_format_bounding_box, get_dimensions_image_pil, get_dimensions_image_tensor
+from ._meta import (
+    convert_format_bounding_box,
+    get_dimensions_image_tensor,
+    get_spatial_size_image_pil,
+    get_spatial_size_image_tensor,
+)
 horizontal_flip_image_tensor = _FT.hflip
 horizontal_flip_image_pil = _FP.hflip
@@ -323,7 +328,7 @@ def affine_image_pil(
     # it is visually better to estimate the center without 0.5 offset
     # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
     if center is None:
-        _, height, width = get_dimensions_image_pil(image)
+        height, width = get_spatial_size_image_pil(image)
         center = [width * 0.5, height * 0.5]
     matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
@@ -1189,13 +1194,13 @@ def _center_crop_compute_crop_anchor(
 def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    _, image_height, image_width = get_dimensions_image_tensor(image)
+    image_height, image_width = get_spatial_size_image_tensor(image)
     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
         image = pad_image_tensor(image, padding_ltrb, fill=0)
-        _, image_height, image_width = get_dimensions_image_tensor(image)
+        image_height, image_width = get_spatial_size_image_tensor(image)
     if crop_width == image_width and crop_height == image_height:
         return image
@@ -1206,13 +1211,13 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
 @torch.jit.unused
 def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    _, image_height, image_width = get_dimensions_image_pil(image)
+    image_height, image_width = get_spatial_size_image_pil(image)
     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
         image = pad_image_pil(image, padding_ltrb, fill=0)
-        _, image_height, image_width = get_dimensions_image_pil(image)
+        image_height, image_width = get_spatial_size_image_pil(image)
     if crop_width == image_width and crop_height == image_height:
         return image
@@ -1365,7 +1370,7 @@ def five_crop_image_tensor(
     image: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     crop_height, crop_width = _parse_five_crop_size(size)
-    _, image_height, image_width = get_dimensions_image_tensor(image)
+    image_height, image_width = get_spatial_size_image_tensor(image)
     if crop_width > image_width or crop_height > image_height:
         msg = "Requested crop size {} is bigger than input size {}"
@@ -1385,7 +1390,7 @@ def five_crop_image_pil(
     image: PIL.Image.Image, size: List[int]
 ) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
     crop_height, crop_width = _parse_five_crop_size(size)
-    _, image_height, image_width = get_dimensions_image_pil(image)
+    image_height, image_width = get_spatial_size_image_pil(image)
     if crop_width > image_width or crop_height > image_height:
         msg = "Requested crop size {} is bigger than input size {}"
...
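The crop kernels above only ever needed height and width, so they now call the spatial-size kernels instead of discarding the channel entry. A small sketch (arbitrary shape, imports taken from the expanded export list above) contrasting the two unpacking styles:

```python
# Sketch: both helpers agree on [height, width]; the new one just skips channels.
import torch
from torchvision.prototype.transforms.functional import (
    get_dimensions_image_tensor,
    get_spatial_size_image_tensor,
)

image = torch.rand(3, 240, 320)

_, h_old, w_old = get_dimensions_image_tensor(image)  # old style: discard channels
h_new, w_new = get_spatial_size_image_tensor(image)   # new style: [height, width]
assert (h_old, w_old) == (h_new, w_new) == (240, 320)
```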
@@ -6,38 +6,37 @@ from torchvision.prototype import features
 from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

 get_dimensions_image_tensor = _FT.get_dimensions
 get_dimensions_image_pil = _FP.get_dimensions

-# TODO: Should this be prefixed with `_` similar to other methods that don't get exposed by init?
-def get_chw(image: features.ImageOrVideoTypeJIT) -> Tuple[int, int, int]:
+def get_dimensions(image: features.ImageOrVideoTypeJIT) -> List[int]:
     if isinstance(image, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
     ):
-        channels, height, width = get_dimensions_image_tensor(image)
+        return get_dimensions_image_tensor(image)
     elif isinstance(image, (features.Image, features.Video)):
         channels = image.num_channels
         height, width = image.image_size
-    else:  # isinstance(image, PIL.Image.Image)
-        channels, height, width = get_dimensions_image_pil(image)
-    return channels, height, width
+        return [channels, height, width]
+    else:
+        return get_dimensions_image_pil(image)

-# The three functions below are here for BC. Whether we want to have two different kernels and how they and the
-# compound version should be named is still under discussion: https://github.com/pytorch/vision/issues/6491
-# Given that these kernels should also support boxes, masks, and videos, it is unlikely that there name will stay.
-# They will either be deprecated or simply aliased to the new kernels if we have reached consensus about the issue
-# detailed above.
-def get_dimensions(image: features.ImageOrVideoTypeJIT) -> List[int]:
-    return list(get_chw(image))
+get_num_channels_image_tensor = _FT.get_image_num_channels
+get_num_channels_image_pil = _FP.get_image_num_channels

 def get_num_channels(image: features.ImageOrVideoTypeJIT) -> int:
-    num_channels, *_ = get_chw(image)
-    return num_channels
+    if isinstance(image, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
+    ):
+        return _FT.get_image_num_channels(image)
+    elif isinstance(image, (features.Image, features.Video)):
+        return image.num_channels
+    else:
+        return _FP.get_image_num_channels(image)

 # We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
@@ -45,9 +44,28 @@ def get_num_channels(image: features.ImageOrVideoTypeJIT) -> int:
 get_image_num_channels = get_num_channels

-def get_spatial_size(image: features.ImageOrVideoTypeJIT) -> List[int]:
-    _, *size = get_chw(image)
-    return size
+def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]:
+    width, height = _FT.get_image_size(image)
+    return [height, width]
+
+
+@torch.jit.unused
+def get_spatial_size_image_pil(image: PIL.Image.Image) -> List[int]:
+    width, height = _FP.get_image_size(image)
+    return [height, width]
+
+
+def get_spatial_size(inpt: features.InputTypeJIT) -> List[int]:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return get_spatial_size_image_tensor(inpt)
+    elif isinstance(inpt, features._Feature):
+        image_size = getattr(inpt, "image_size", None)
+        if image_size is not None:
+            return list(image_size)
+        else:
+            raise ValueError(f"Type {inpt.__class__} doesn't have spatial size.")
+    else:
+        return get_spatial_size_image_pil(inpt)

 def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
...
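A sketch (not part of the commit) of the broadened `get_spatial_size` dispatch defined above: any `_Feature` with an `image_size` attribute now works, everything else raises. The `BoundingBox` and `Label` constructor arguments here are illustrative assumptions about the prototype API rather than something taken from this diff.

```python
# Sketch: get_spatial_size now accepts non-image features via their image_size.
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms.functional import get_spatial_size

image = features.Image(torch.rand(3, 240, 320))
box = features.BoundingBox([[0, 0, 10, 10]], format="XYXY", image_size=(240, 320))

print(get_spatial_size(image))  # [240, 320] via Image.image_size
print(get_spatial_size(box))    # [240, 320] via BoundingBox.image_size

label = features.Label(torch.tensor(1))  # assumed constructor; no image_size attribute
try:
    get_spatial_size(label)
except ValueError as exc:
    print(exc)  # "Type ... doesn't have spatial size."
```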