"vscode:/vscode.git/clone" did not exist on "5e8a21168c5705faa9792829cd9b10808ed83b50"
Unverified commit 316cc25c, authored by Philip Meier, committed by GitHub.
Browse files

Ten crop annotation (#7254)

parent f0b70002
......@@ -234,7 +234,18 @@ class TenCrop(Transform):
def _transform(
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
) -> Union[List[datapoints.ImageTypeJIT], List[datapoints.VideoTypeJIT]]:
) -> Tuple[
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
]:
return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
......
......@@ -1964,8 +1964,6 @@ def five_crop(
if not torch.jit.is_scripting():
_log_api_usage_once(five_crop)
# TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with
# `ten_crop`
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return five_crop_image_tensor(inpt, size)
elif isinstance(inpt, datapoints.Image):
......@@ -1983,40 +1981,90 @@ def five_crop(
)
def ten_crop_image_tensor(
    image: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Crop ``image`` into four corners and the center, plus the same five crops
    of the flipped image.

    Args:
        image: the input image tensor.
        size: desired crop size, as expected by ``five_crop_image_tensor``.
        vertical_flip: if True, flip vertically for the second set of crops;
            otherwise flip horizontally (the default).

    Returns:
        A 10-tuple: the five non-flipped crops followed by the five flipped ones.
    """
    non_flipped = five_crop_image_tensor(image, size)

    if vertical_flip:
        image = vertical_flip_image_tensor(image)
    else:
        image = horizontal_flip_image_tensor(image)

    flipped = five_crop_image_tensor(image, size)

    # five_crop_image_tensor returns a 5-tuple, so tuple concatenation yields
    # exactly the length-10 tuple promised by the annotation.
    return non_flipped + flipped
@torch.jit.unused
def ten_crop_image_pil(
    image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
) -> Tuple[
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
]:
    """PIL counterpart of ``ten_crop_image_tensor``.

    Args:
        image: the input PIL image.
        size: desired crop size, as expected by ``five_crop_image_pil``.
        vertical_flip: if True, flip vertically for the second set of crops;
            otherwise flip horizontally (the default).

    Returns:
        A 10-tuple: the five non-flipped crops followed by the five flipped ones.
    """
    non_flipped = five_crop_image_pil(image, size)

    if vertical_flip:
        image = vertical_flip_image_pil(image)
    else:
        image = horizontal_flip_image_pil(image)

    flipped = five_crop_image_pil(image, size)

    # five_crop_image_pil returns a 5-tuple; concatenation gives the 10-tuple.
    return non_flipped + flipped
def ten_crop_video(
    video: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Ten-crop a video tensor.

    Videos are handled identically to image tensors here, so this simply
    delegates to ``ten_crop_image_tensor`` and returns its 10-tuple of crops.
    """
    return ten_crop_image_tensor(video, size, vertical_flip=vertical_flip)
def ten_crop(
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], size: List[int], vertical_flip: bool = False
) -> Union[List[datapoints.ImageTypeJIT], List[datapoints.VideoTypeJIT]]:
) -> Tuple[
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
]:
if not torch.jit.is_scripting():
_log_api_usage_once(ten_crop)
......
......@@ -827,7 +827,9 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten
return tl, tr, bl, br, center
def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[Tensor]:
def ten_crop(
img: Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Generate ten cropped images from the given image.
Crop the given image into four corners and the central crop plus the
flipped version of these (horizontal flipping is used by default).
......
......@@ -1049,7 +1049,7 @@ class TenCrop(torch.nn.Module):
Example:
>>> transform = Compose([
>>> TenCrop(size), # this is a tuple of PIL Images
>>> Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
>>> ])
>>> #In your test loop you can do the following:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.