"vscode:/vscode.git/clone" did not exist on "2dec28aeafd159879b4f437de25c10f7d7139679"
Unverified commit 7039c2c3 authored by Philip Meier, committed by GitHub

port FiveCrop and TenCrop to prototype API (#5513)



* port FiveCrop and TenCrop to prototype API

* fix ten crop for pil

* Update torchvision/prototype/transforms/_geometry.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* simplify implementation

* minor cleanup
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 79892d37
torchvision/prototype/transforms/__init__.py
@@ -7,7 +7,7 @@ from ._transform import Transform  # usort: skip
from ._augment import RandomErasing, RandomMixup, RandomCutmix
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
-from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop
+from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop, FiveCrop, TenCrop, BatchMultiCrop
from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
from ._misc import Identity, Normalize, ToDtype, Lambda
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
torchvision/prototype/transforms/_geometry.py
@@ ... @@
import collections.abc
import math
import warnings
from typing import Any, Dict, List, Union, Sequence, Tuple, cast
@@ -6,6 +7,7 @@ import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
from torchvision.transforms.functional import pil_to_tensor
from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor
@@ -168,3 +170,89 @@ class RandomResizedCrop(Transform):
        if has_any(sample, features.BoundingBox, features.SegmentationMask):
            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
        return super().forward(sample)


class MultiCropResult(list):
    """Helper class for :class:`~torchvision.prototype.transforms.BatchMultiCrop`.

    Outputs of multi crop transforms such as :class:`~torchvision.prototype.transforms.FiveCrop` and
    :class:`~torchvision.prototype.transforms.TenCrop` should be wrapped in this in order to be batched correctly by
    :class:`~torchvision.prototype.transforms.BatchMultiCrop`.
    """

    pass


class FiveCrop(Transform):
    def __init__(self, size: Union[int, Sequence[int]]) -> None:
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        if isinstance(input, features.Image):
            output = F.five_crop_image_tensor(input, self.size)
            return MultiCropResult(features.Image.new_like(input, o) for o in output)
        elif is_simple_tensor(input):
            return MultiCropResult(F.five_crop_image_tensor(input, self.size))
        elif isinstance(input, PIL.Image.Image):
            return MultiCropResult(F.five_crop_image_pil(input, self.size))
        else:
            return input

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        if has_any(sample, features.BoundingBox, features.SegmentationMask):
            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
        return super().forward(sample)


class TenCrop(Transform):
    def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
        self.vertical_flip = vertical_flip

    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        if isinstance(input, features.Image):
            output = F.ten_crop_image_tensor(input, self.size, vertical_flip=self.vertical_flip)
            return MultiCropResult(features.Image.new_like(input, o) for o in output)
        elif is_simple_tensor(input):
            return MultiCropResult(F.ten_crop_image_tensor(input, self.size, vertical_flip=self.vertical_flip))
        elif isinstance(input, PIL.Image.Image):
            return MultiCropResult(F.ten_crop_image_pil(input, self.size, vertical_flip=self.vertical_flip))
        else:
            return input

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        if has_any(sample, features.BoundingBox, features.SegmentationMask):
            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
        return super().forward(sample)


class BatchMultiCrop(Transform):
    def forward(self, *inputs: Any) -> Any:
        # This is basically the functionality of `torchvision.prototype.utils._internal.apply_recursively` with one
        # significant difference:
        # Since we need multiple images to batch them together, we need to explicitly exclude `MultiCropResult` from
        # the sequence case.
        def apply_recursively(obj: Any) -> Any:
            if isinstance(obj, MultiCropResult):
                crops = obj
                if isinstance(obj[0], PIL.Image.Image):
                    crops = [pil_to_tensor(crop) for crop in crops]  # type: ignore[assignment]

                batch = torch.stack(crops)

                if isinstance(obj[0], features.Image):
                    batch = features.Image.new_like(obj[0], batch)

                return batch
            elif isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
                return [apply_recursively(item) for item in obj]
            elif isinstance(obj, collections.abc.Mapping):
                return {key: apply_recursively(item) for key, item in obj.items()}
            else:
                return obj

        return apply_recursively(inputs if len(inputs) > 1 else inputs[0])
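
Taken together, the three classes form a small pipeline: FiveCrop/TenCrop emit a MultiCropResult, and BatchMultiCrop stacks it into a single batched image. A minimal usage sketch, assuming the classes are exposed as torchvision.prototype.transforms per the __init__ hunk above; the input size, crop size, and the shapes in the comments are illustrative:

import torch
from torchvision.prototype import features, transforms

five_crop = transforms.FiveCrop(224)
ten_crop = transforms.TenCrop(224, vertical_flip=False)
batch_multi_crop = transforms.BatchMultiCrop()

image = features.Image(torch.rand(3, 256, 256))

# FiveCrop returns a MultiCropResult holding 5 crops of shape (3, 224, 224) ...
crops = five_crop(image)
# ... which BatchMultiCrop stacks into a single image of shape (5, 3, 224, 224).
batch = batch_multi_crop(crops)

# TenCrop additionally appends the five crops of the flipped image,
# yielding a batch of shape (10, 3, 224, 224).
batch = batch_multi_crop(ten_crop(image))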
torchvision/prototype/transforms/functional/__init__.py
@@ -60,6 +60,10 @@ from ._geometry import (
    perspective_image_pil,
    vertical_flip_image_tensor,
    vertical_flip_image_pil,
    five_crop_image_tensor,
    five_crop_image_pil,
    ten_crop_image_tensor,
    ten_crop_image_pil,
)
from ._misc import normalize_image_tensor, gaussian_blur_image_tensor
from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot
torchvision/prototype/transforms/functional/_geometry.py
@@ -314,3 +314,79 @@ def resized_crop_image_pil(
) -> PIL.Image.Image:
    img = crop_image_pil(img, top, left, height, width)
    return resize_image_pil(img, size, interpolation=interpolation)


def _parse_five_crop_size(size: List[int]) -> List[int]:
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])  # type: ignore[assignment]

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    return size


def five_crop_image_tensor(
    img: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    crop_height, crop_width = _parse_five_crop_size(size)
    _, image_height, image_width = get_dimensions_image_tensor(img)

    if crop_width > image_width or crop_height > image_height:
        msg = "Requested crop size {} is bigger than input size {}"
        raise ValueError(msg.format(size, (image_height, image_width)))

    tl = crop_image_tensor(img, 0, 0, crop_height, crop_width)
    tr = crop_image_tensor(img, 0, image_width - crop_width, crop_height, crop_width)
    bl = crop_image_tensor(img, image_height - crop_height, 0, crop_height, crop_width)
    br = crop_image_tensor(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
    center = center_crop_image_tensor(img, [crop_height, crop_width])

    return tl, tr, bl, br, center


def five_crop_image_pil(
    img: PIL.Image.Image, size: List[int]
) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
    crop_height, crop_width = _parse_five_crop_size(size)
    _, image_height, image_width = get_dimensions_image_pil(img)

    if crop_width > image_width or crop_height > image_height:
        msg = "Requested crop size {} is bigger than input size {}"
        raise ValueError(msg.format(size, (image_height, image_width)))

    tl = crop_image_pil(img, 0, 0, crop_height, crop_width)
    tr = crop_image_pil(img, 0, image_width - crop_width, crop_height, crop_width)
    bl = crop_image_pil(img, image_height - crop_height, 0, crop_height, crop_width)
    br = crop_image_pil(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
    center = center_crop_image_pil(img, [crop_height, crop_width])

    return tl, tr, bl, br, center


def ten_crop_image_tensor(img: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
    tl, tr, bl, br, center = five_crop_image_tensor(img, size)

    if vertical_flip:
        img = vertical_flip_image_tensor(img)
    else:
        img = horizontal_flip_image_tensor(img)

    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_tensor(img, size)

    return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]


def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: bool = False) -> List[PIL.Image.Image]:
    tl, tr, bl, br, center = five_crop_image_pil(img, size)

    if vertical_flip:
        img = vertical_flip_image_pil(img)
    else:
        img = horizontal_flip_image_pil(img)

    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_pil(img, size)

    return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
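
The kernels above implement the classic five/ten crop geometry: the four corners plus the center, with ten crop repeating those five on a flipped copy. A small sketch of that contract, assuming the kernels are re-exported from torchvision.prototype.transforms.functional as in the functional __init__ hunk above; the tensor contents and sizes are illustrative:

import torch
from torchvision.prototype.transforms import functional as F

img = torch.arange(3 * 8 * 8, dtype=torch.float32).reshape(3, 8, 8)

# The five crops are the four corners plus the center crop.
tl, tr, bl, br, center = F.five_crop_image_tensor(img, [4, 4])
assert torch.equal(tl, img[..., :4, :4])  # top-left corner
assert torch.equal(br, img[..., 4:, 4:])  # bottom-right corner
assert torch.equal(center, img[..., 2:6, 2:6])

# ten_crop repeats the five crops on the flipped image
# (horizontal flip by default, vertical if vertical_flip=True).
crops = F.ten_crop_image_tensor(img, [4, 4], vertical_flip=False)
assert len(crops) == 10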