Unverified commit cff5435f, authored by Philip Meier, committed by GitHub

remove BatchMultiCrop (#6460)

* remove BatchMultiCrop

* address review

* let FiveCrop return tuples
parent aea748b3
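
In short: the standalone `BatchMultiCrop` transform is removed, and the multi-crop transforms now return plain containers of crops that callers batch themselves. A minimal sketch of the replacement pattern, assuming the prototype `transforms` and `features` namespaces used in the diff below:

    import torch
    from torchvision.prototype import features, transforms

    # FiveCrop now returns a plain tuple of five crops; stacking them
    # into a batch is left to the caller instead of BatchMultiCrop.
    image = features.Image(torch.rand(3, 256, 256))
    crops = transforms.FiveCrop(size=224)(image)
    batch = features.Image.new_like(crops[0], torch.stack(crops))
    print(batch.shape)  # torch.Size([5, 3, 224, 224])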
@@ -16,7 +16,6 @@ from ._color import (
 )
 from ._container import Compose, RandomApply, RandomChoice, RandomOrder
 from ._geometry import (
-    BatchMultiCrop,
     CenterCrop,
     ElasticTransform,
     FiveCrop,
...
@@ -8,7 +8,7 @@ import torch
 from torchvision.ops.boxes import box_iou
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform
-from torchvision.transforms.functional import InterpolationMode, pil_to_tensor
+from torchvision.transforms.functional import InterpolationMode
 from torchvision.transforms.functional_tensor import _parse_pad_padding
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size
@@ -136,30 +136,41 @@ class RandomResizedCrop(Transform):
 )


-class MultiCropResult(list):
-    """Helper class for :class:`~torchvision.prototype.transforms.BatchMultiCrop`.
-
-    Outputs of multi crop transforms such as :class:`~torchvision.prototype.transforms.FiveCrop` and
-    :class:`~torchvision.prototype.transforms.TenCrop` should be wrapped in this in order to be batched correctly by
-    :class:`~torchvision.prototype.transforms.BatchMultiCrop`.
-    """
-
-    pass
-
-
 class FiveCrop(Transform):
+    """
+    Example:
+        >>> class BatchMultiCrop(transforms.Transform):
+        ...     def forward(self, sample: Tuple[Tuple[features.Image, ...], features.Label]):
+        ...         images, labels = sample
+        ...         batch_size = len(images)
+        ...         images = features.Image.new_like(images[0], torch.stack(images))
+        ...         labels = features.Label.new_like(labels, labels.repeat(batch_size))
+        ...         return images, labels
+        ...
+        >>> image = features.Image(torch.rand(3, 256, 256))
+        >>> label = features.Label(0)
+        >>> transform = transforms.Compose([transforms.FiveCrop(size=224), BatchMultiCrop()])
+        >>> images, labels = transform(image, label)
+        >>> images.shape
+        torch.Size([5, 3, 224, 224])
+        >>> labels.shape
+        torch.Size([5])
+    """
+
     def __init__(self, size: Union[int, Sequence[int]]) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        # TODO: returning a list is technically BC breaking since FiveCrop returned a tuple before. We switched to a
-        # list here to align it with TenCrop.
         if isinstance(inpt, features.Image):
             output = F.five_crop_image_tensor(inpt, self.size)
-            return MultiCropResult(features.Image.new_like(inpt, o) for o in output)
+            return tuple(features.Image.new_like(inpt, o) for o in output)
         elif is_simple_tensor(inpt):
-            return MultiCropResult(F.five_crop_image_tensor(inpt, self.size))
+            return F.five_crop_image_tensor(inpt, self.size)
         elif isinstance(inpt, PIL.Image.Image):
-            return MultiCropResult(F.five_crop_image_pil(inpt, self.size))
+            return F.five_crop_image_pil(inpt, self.size)
         else:
             return inpt
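
For the plain-tensor branch above, a quick sketch of the new return type (assuming `five_crop_image_tensor` mirrors the crop order of the classic `five_crop` functional):

    import torch
    from torchvision.prototype.transforms import functional as F

    img = torch.rand(3, 256, 256)
    # The functional returns the (top-left, top-right, bottom-left,
    # bottom-right, center) crops as a tuple of tensors.
    tl, tr, bl, br, center = F.five_crop_image_tensor(img, [224, 224])
    assert center.shape == torch.Size([3, 224, 224])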
@@ -171,6 +182,10 @@ class FiveCrop(Transform):


 class TenCrop(Transform):
+    """
+    See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
+    """
+
     def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
@@ -179,11 +194,11 @@ class TenCrop(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if isinstance(inpt, features.Image):
             output = F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip)
-            return MultiCropResult(features.Image.new_like(inpt, o) for o in output)
+            return [features.Image.new_like(inpt, o) for o in output]
         elif is_simple_tensor(inpt):
-            return MultiCropResult(F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip))
+            return F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip)
         elif isinstance(inpt, PIL.Image.Image):
-            return MultiCropResult(F.ten_crop_image_pil(inpt, self.size, vertical_flip=self.vertical_flip))
+            return F.ten_crop_image_pil(inpt, self.size, vertical_flip=self.vertical_flip)
         else:
             return inpt
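
Usage is analogous for `TenCrop`; a sketch under the same assumptions (note that the `features.Image` branch here returns a list rather than a tuple, matching the diff):

    import torch
    from torchvision.prototype import features, transforms

    image = features.Image(torch.rand(3, 256, 256))
    # Ten crops: the five crops plus the five crops of the flipped
    # image (vertically flipped when vertical_flip=True).
    crops = transforms.TenCrop(size=224, vertical_flip=True)(image)
    assert len(crops) == 10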
@@ -194,22 +209,6 @@ class TenCrop(Transform):
         return super().forward(sample)


-class BatchMultiCrop(Transform):
-    _transformed_types = (MultiCropResult,)
-
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        crops = inpt
-        if isinstance(inpt[0], PIL.Image.Image):
-            crops = [pil_to_tensor(crop) for crop in crops]
-        batch = torch.stack(crops)
-        if isinstance(inpt[0], features.Image):
-            batch = features.Image.new_like(inpt[0], batch)
-        return batch
-
-
 def _check_fill_arg(fill: Union[int, float, Sequence[int], Sequence[float]]) -> None:
     if not isinstance(fill, (numbers.Number, tuple, list)):
         raise TypeError("Got inappropriate fill arg")
...
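
With the class above removed, its PIL branch can be replicated inline; a sketch assuming `FiveCrop` hands PIL inputs to `F.five_crop_image_pil` as shown earlier:

    import PIL.Image
    import torch
    from torchvision.prototype import transforms
    from torchvision.transforms.functional import pil_to_tensor

    pil_image = PIL.Image.new("RGB", (256, 256))
    crops = transforms.FiveCrop(size=224)(pil_image)  # tuple of PIL images
    # pil_to_tensor converts each crop to a uint8 CHW tensor before stacking.
    batch = torch.stack([pil_to_tensor(crop) for crop in crops])
    print(batch.shape)  # torch.Size([5, 3, 224, 224])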