"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "27402cb7a28555a3efcaa5af054b1ce2d18e5442"
Unverified commit 7039c2c3, authored by Philip Meier, committed by GitHub

port FiveCrop and TenCrop to prototype API (#5513)



* port FiveCrop and TenCrop to prototype API

* fix ten crop for pil

* Update torchvision/prototype/transforms/_geometry.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* simplify implementation

* minor cleanup
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 79892d37
@@ -7,7 +7,7 @@ from ._transform import Transform  # usort: skip

from ._augment import RandomErasing, RandomMixup, RandomCutmix
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
-from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop
+from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop, FiveCrop, TenCrop, BatchMultiCrop
from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
from ._misc import Identity, Normalize, ToDtype, Lambda
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
...
import collections.abc
import math
import warnings
from typing import Any, Dict, List, Union, Sequence, Tuple, cast
@@ -6,6 +7,7 @@ import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
from torchvision.transforms.functional import pil_to_tensor
from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor
@@ -168,3 +170,89 @@ class RandomResizedCrop(Transform):
        if has_any(sample, features.BoundingBox, features.SegmentationMask):
            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
        return super().forward(sample)
class MultiCropResult(list):
    """Helper class for :class:`~torchvision.prototype.transforms.BatchMultiCrop`.

    Outputs of multi crop transforms such as :class:`~torchvision.prototype.transforms.FiveCrop` and
    :class:`~torchvision.prototype.transforms.TenCrop` should be wrapped in this in order to be batched correctly by
    :class:`~torchvision.prototype.transforms.BatchMultiCrop`.
    """

    pass
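# Illustrative note (not part of the diff): FiveCrop below returns a
# MultiCropResult holding five crops; BatchMultiCrop recognizes this wrapper and
# stacks it into a single tensor of shape (5, C, H, W) ((10, C, H, W) for TenCrop).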
class FiveCrop(Transform):
    def __init__(self, size: Union[int, Sequence[int]]) -> None:
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        if isinstance(input, features.Image):
            output = F.five_crop_image_tensor(input, self.size)
            return MultiCropResult(features.Image.new_like(input, o) for o in output)
        elif is_simple_tensor(input):
            return MultiCropResult(F.five_crop_image_tensor(input, self.size))
        elif isinstance(input, PIL.Image.Image):
            return MultiCropResult(F.five_crop_image_pil(input, self.size))
        else:
            return input

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        if has_any(sample, features.BoundingBox, features.SegmentationMask):
            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
        return super().forward(sample)
class TenCrop(Transform):
    def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
        self.vertical_flip = vertical_flip

    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        if isinstance(input, features.Image):
            output = F.ten_crop_image_tensor(input, self.size, vertical_flip=self.vertical_flip)
            return MultiCropResult(features.Image.new_like(input, o) for o in output)
        elif is_simple_tensor(input):
            return MultiCropResult(F.ten_crop_image_tensor(input, self.size, vertical_flip=self.vertical_flip))
        elif isinstance(input, PIL.Image.Image):
            return MultiCropResult(F.ten_crop_image_pil(input, self.size, vertical_flip=self.vertical_flip))
        else:
            return input

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        if has_any(sample, features.BoundingBox, features.SegmentationMask):
            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
        return super().forward(sample)
class BatchMultiCrop(Transform):
    def forward(self, *inputs: Any) -> Any:
        # This is basically the functionality of `torchvision.prototype.utils._internal.apply_recursively` with one
        # significant difference:
        # Since we need multiple images to batch them together, we need to explicitly exclude `MultiCropResult` from
        # the sequence case.
        def apply_recursively(obj: Any) -> Any:
            if isinstance(obj, MultiCropResult):
                crops = obj
                if isinstance(obj[0], PIL.Image.Image):
                    crops = [pil_to_tensor(crop) for crop in crops]  # type: ignore[assignment]

                batch = torch.stack(crops)

                if isinstance(obj[0], features.Image):
                    batch = features.Image.new_like(obj[0], batch)

                return batch
            elif isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
                return [apply_recursively(item) for item in obj]
            elif isinstance(obj, collections.abc.Mapping):
                return {key: apply_recursively(item) for key, item in obj.items()}
            else:
                return obj

        return apply_recursively(inputs if len(inputs) > 1 else inputs[0])
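To make the batching behavior concrete, here is a minimal usage sketch (illustrative only, not part of the commit; it assumes the prototype namespace exposes the transforms exactly as imported in the first hunk above):

import torch
from torchvision.prototype import transforms

# FiveCrop emits a MultiCropResult; BatchMultiCrop stacks it into one tensor.
pipeline = transforms.Compose([
    transforms.FiveCrop(224),
    transforms.BatchMultiCrop(),
])

image = torch.rand(3, 256, 256)  # a plain ("simple") image tensor
batch = pipeline(image)
print(batch.shape)  # torch.Size([5, 3, 224, 224])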
@@ -60,6 +60,10 @@ from ._geometry import (
    perspective_image_pil,
    vertical_flip_image_tensor,
    vertical_flip_image_pil,
    five_crop_image_tensor,
    five_crop_image_pil,
    ten_crop_image_tensor,
    ten_crop_image_pil,
)
from ._misc import normalize_image_tensor, gaussian_blur_image_tensor
from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot
@@ -314,3 +314,79 @@ def resized_crop_image_pil(
) -> PIL.Image.Image:
    img = crop_image_pil(img, top, left, height, width)
    return resize_image_pil(img, size, interpolation=interpolation)
def _parse_five_crop_size(size: List[int]) -> List[int]:
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])  # type: ignore[assignment]

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    return size
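# Illustrative behavior (not part of the diff):
#   _parse_five_crop_size(224)        -> (224, 224)
#   _parse_five_crop_size([224])      -> (224, 224)
#   _parse_five_crop_size([224, 224]) -> [224, 224]  (already two dimensions, returned as-is)
#   _parse_five_crop_size([1, 2, 3])  -> ValueError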
def five_crop_image_tensor(
    img: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    crop_height, crop_width = _parse_five_crop_size(size)
    _, image_height, image_width = get_dimensions_image_tensor(img)

    if crop_width > image_width or crop_height > image_height:
        msg = "Requested crop size {} is bigger than input size {}"
        raise ValueError(msg.format(size, (image_height, image_width)))

    tl = crop_image_tensor(img, 0, 0, crop_height, crop_width)
    tr = crop_image_tensor(img, 0, image_width - crop_width, crop_height, crop_width)
    bl = crop_image_tensor(img, image_height - crop_height, 0, crop_height, crop_width)
    br = crop_image_tensor(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
    center = center_crop_image_tensor(img, [crop_height, crop_width])

    return tl, tr, bl, br, center
def five_crop_image_pil(
    img: PIL.Image.Image, size: List[int]
) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
    crop_height, crop_width = _parse_five_crop_size(size)
    _, image_height, image_width = get_dimensions_image_pil(img)

    if crop_width > image_width or crop_height > image_height:
        msg = "Requested crop size {} is bigger than input size {}"
        raise ValueError(msg.format(size, (image_height, image_width)))

    tl = crop_image_pil(img, 0, 0, crop_height, crop_width)
    tr = crop_image_pil(img, 0, image_width - crop_width, crop_height, crop_width)
    bl = crop_image_pil(img, image_height - crop_height, 0, crop_height, crop_width)
    br = crop_image_pil(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
    center = center_crop_image_pil(img, [crop_height, crop_width])

    return tl, tr, bl, br, center
def ten_crop_image_tensor(img: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
    tl, tr, bl, br, center = five_crop_image_tensor(img, size)

    if vertical_flip:
        img = vertical_flip_image_tensor(img)
    else:
        img = horizontal_flip_image_tensor(img)

    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_tensor(img, size)

    return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: bool = False) -> List[PIL.Image.Image]:
    tl, tr, bl, br, center = five_crop_image_pil(img, size)

    if vertical_flip:
        img = vertical_flip_image_pil(img)
    else:
        img = horizontal_flip_image_pil(img)

    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_pil(img, size)

    return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
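As a quick sanity check of the functional kernels above, the following sketch (again illustrative, not part of the commit) exercises the tensor paths directly; the input sizes are arbitrary assumptions:

import torch
from torchvision.prototype.transforms.functional import (
    five_crop_image_tensor,
    ten_crop_image_tensor,
)

img = torch.rand(3, 256, 256)

# Four corner crops plus a center crop, each of the requested size.
tl, tr, bl, br, center = five_crop_image_tensor(img, [224, 224])
assert all(c.shape == (3, 224, 224) for c in (tl, tr, bl, br, center))

# Ten crops: the five above plus five from the horizontally flipped image
# (or the vertically flipped image when vertical_flip=True).
crops = ten_crop_image_tensor(img, [224, 224])
assert len(crops) == 10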