Unverified commit a3fe870b, authored by Vasilis Vryniotis and committed by GitHub

Adding support of Video to remaining Transforms and Kernels (#6724)

* Adding support of Video to missed Transforms and Kernels

* Fixing Grayscale Transform.

* Fixing FiveCrop and TenCrop Transforms.

* Fix Linter

* Fix more kernels.

* Add `five_crop_video` and `ten_crop_video` kernels

* Added a TODO.

* Missed Video isinstance

* nits

* Fix bug on AugMix

* Nits and TODOs.

* Reapply Philip's recommendation

* Fix mypy and JIT

* Fixing test
parent 0ab50f5f
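For orientation, here is a minimal sketch (not part of the commit) of what the added Video support enables at the transform level. It assumes the torchvision.prototype namespace used throughout this diff and a [T, C, H, W] video layout; the printed values are only the expected results.

import torch
from torchvision.prototype import features, transforms

video = features.Video(torch.rand(8, 3, 256, 256))  # assumed [T, C, H, W] layout

# FiveCrop now lists features.Video in _transformed_types, so the video is cropped
# instead of being passed through untouched.
five = transforms.FiveCrop(224)(video)
print(len(five), type(five[0]).__name__, five[0].shape)  # expected: 5 Video torch.Size([8, 3, 224, 224])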
@@ -13,4 +13,14 @@ from ._image import (
 )
 from ._label import Label, OneHotLabel
 from ._mask import Mask
-from ._video import ImageOrVideoType, ImageOrVideoTypeJIT, TensorImageOrVideoType, TensorImageOrVideoTypeJIT, Video
+from ._video import (
+    ImageOrVideoType,
+    ImageOrVideoTypeJIT,
+    LegacyVideoType,
+    LegacyVideoTypeJIT,
+    TensorImageOrVideoType,
+    TensorImageOrVideoTypeJIT,
+    Video,
+    VideoType,
+    VideoTypeJIT,
+)
@@ -238,6 +238,7 @@ LegacyVideoTypeJIT = torch.Tensor
 TensorVideoType = Union[torch.Tensor, Video]
 TensorVideoTypeJIT = torch.Tensor
 
+# TODO: decide if we should do definitions for both Images and Videos or use unions in the methods
 ImageOrVideoType = Union[ImageType, VideoType]
 ImageOrVideoTypeJIT = Union[ImageTypeJIT, VideoTypeJIT]
 TensorImageOrVideoType = Union[TensorImageType, TensorVideoType]
...
@@ -99,6 +99,7 @@ class RandomErasing(_RandomApplyTransform):
         return inpt
 
 
+# TODO: Add support for Video: https://github.com/pytorch/vision/issues/6731
 class _BaseMixupCutmix(_RandomApplyTransform):
     def __init__(self, alpha: float, p: float = 0.5) -> None:
         super().__init__(p=p)
...
@@ -483,7 +483,8 @@ class AugMix(_AutoAugmentBase):
         augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
 
         orig_dims = list(image_or_video.shape)
-        batch = image_or_video.view([1] * max(4 - image_or_video.ndim, 0) + orig_dims)
+        expected_dim = 5 if isinstance(orig_image_or_video, features.Video) else 4
+        batch = image_or_video.view([1] * max(expected_dim - image_or_video.ndim, 0) + orig_dims)
         batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
 
         # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
@@ -520,7 +521,7 @@
         mix = mix.view(orig_dims).to(dtype=image_or_video.dtype)
 
         if isinstance(orig_image_or_video, (features.Image, features.Video)):
-            mix = type(orig_image_or_video).wrap_like(orig_image_or_video, mix)  # type: ignore[arg-type]
+            mix = orig_image_or_video.wrap_like(orig_image_or_video, mix)  # type: ignore[arg-type]
         elif isinstance(orig_image_or_video, PIL.Image.Image):
             mix = F.to_image_pil(mix)
...
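The `expected_dim` change above accounts for the extra temporal dimension that a `features.Video` carries. A standalone sketch of the reshape arithmetic (layouts assumed as [C, H, W] for images and [T, C, H, W] for videos):

import torch

def as_batch(t: torch.Tensor, expected_dim: int) -> torch.Tensor:
    orig_dims = list(t.shape)
    # prepend singleton dims until the tensor reaches `expected_dim` dimensions
    return t.view([1] * max(expected_dim - t.ndim, 0) + orig_dims)

image = torch.rand(3, 32, 32)     # ndim=3, expected_dim=4 -> one leading singleton dim
video = torch.rand(8, 3, 32, 32)  # ndim=4, expected_dim=5 -> one leading singleton dim

print(as_batch(image, 4).shape)  # torch.Size([1, 3, 32, 32])
print(as_batch(video, 5).shape)  # torch.Size([1, 8, 3, 32, 32])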
@@ -119,7 +119,7 @@ class RandomPhotometricDistort(Transform):
         output = inpt[..., permutation, :, :]
 
         if isinstance(inpt, (features.Image, features.Video)):
-            output = type(inpt).wrap_like(inpt, output, color_space=features.ColorSpace.OTHER)  # type: ignore[arg-type]
+            output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.OTHER)  # type: ignore[arg-type]
         elif isinstance(inpt, PIL.Image.Image):
             output = F.to_image_pil(output)
...
@@ -29,7 +29,7 @@ class ToTensor(Transform):
 
 
 class Grayscale(Transform):
-    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
+    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)
 
     def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
         deprecation_msg = (
@@ -52,15 +52,15 @@ class Grayscale(Transform):
         super().__init__()
         self.num_output_channels = num_output_channels
 
-    def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
+    def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
         output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
-        if isinstance(inpt, features.Image):
-            output = features.Image.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY)
+        if isinstance(inpt, (features.Image, features.Video)):
+            output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY)  # type: ignore[arg-type]
         return output
 
 
 class RandomGrayscale(_RandomApplyTransform):
-    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
+    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)
 
     def __init__(self, p: float = 0.1) -> None:
         warnings.warn(
@@ -81,8 +81,8 @@ class RandomGrayscale(_RandomApplyTransform):
         num_input_channels, _, _ = query_chw(sample)
         return dict(num_input_channels=num_input_channels)
 
-    def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
+    def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
         output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
-        if isinstance(inpt, features.Image):
-            output = features.Image.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY)
+        if isinstance(inpt, (features.Image, features.Video)):
+            output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY)  # type: ignore[arg-type]
         return output
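The recurring `type(inpt).wrap_like(...)` to `inpt.wrap_like(...)` change works because `wrap_like` is a classmethod: calling it on the input instance dispatches to that instance's class, so the same line re-wraps the raw output as either an Image or a Video. A minimal sketch (import path and layout assumed):

import torch
from torchvision.prototype import features

video = features.Video(torch.rand(4, 3, 8, 8))  # assumed [T, C, H, W] layout
output = torch.rand(4, 1, 8, 8)                 # e.g. the raw result of a grayscale kernel

# wrap_like re-wraps the plain tensor as the same feature type and updates metadata.
wrapped = video.wrap_like(video, output, color_space=features.ColorSpace.GRAY)
print(type(wrapped).__name__)  # expected: Video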
@@ -155,12 +155,13 @@ class FiveCrop(Transform):
     """
     Example:
         >>> class BatchMultiCrop(transforms.Transform):
-        ...     def forward(self, sample: Tuple[Tuple[features.Image, ...], features.Label]):
-        ...         images, labels = sample
-        ...         batch_size = len(images)
-        ...         images = features.Image.wrap_like(images[0], torch.stack(images))
+        ...     def forward(self, sample: Tuple[Tuple[Union[features.Image, features.Video], ...], features.Label]):
+        ...         images_or_videos, labels = sample
+        ...         batch_size = len(images_or_videos)
+        ...         image_or_video = images_or_videos[0]
+        ...         images_or_videos = image_or_video.wrap_like(image_or_video, torch.stack(images_or_videos))
         ...         labels = features.Label.wrap_like(labels, labels.repeat(batch_size))
-        ...         return images, labels
+        ...         return images_or_videos, labels
         ...
         >>> image = features.Image(torch.rand(3, 256, 256))
         >>> label = features.Label(0)
@@ -172,15 +173,21 @@ class FiveCrop(Transform):
     torch.Size([5])
     """
 
-    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
+    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)
 
     def __init__(self, size: Union[int, Sequence[int]]) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
 
     def _transform(
-        self, inpt: features.ImageType, params: Dict[str, Any]
-    ) -> Tuple[features.ImageType, features.ImageType, features.ImageType, features.ImageType, features.ImageType]:
+        self, inpt: features.ImageOrVideoType, params: Dict[str, Any]
+    ) -> Tuple[
+        features.ImageOrVideoType,
+        features.ImageOrVideoType,
+        features.ImageOrVideoType,
+        features.ImageOrVideoType,
+        features.ImageOrVideoType,
+    ]:
         return F.five_crop(inpt, self.size)
 
     def forward(self, *inputs: Any) -> Any:
@@ -194,14 +201,14 @@ class TenCrop(Transform):
     See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
     """
 
-    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
+    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)
 
     def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
         self.vertical_flip = vertical_flip
 
-    def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> List[features.ImageType]:
+    def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> List[features.ImageOrVideoType]:
         return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
 
     def forward(self, *inputs: Any) -> Any:
...
@@ -22,18 +22,18 @@ class ConvertBoundingBoxFormat(Transform):
 class ConvertImageDtype(Transform):
-    _transformed_types = (features.is_simple_tensor, features.Image)
+    _transformed_types = (features.is_simple_tensor, features.Image, features.Video)
 
     def __init__(self, dtype: torch.dtype = torch.float32) -> None:
         super().__init__()
         self.dtype = dtype
 
-    def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> features.TensorImageType:
+    def _transform(
+        self, inpt: features.TensorImageOrVideoType, params: Dict[str, Any]
+    ) -> features.TensorImageOrVideoType:
         output = F.convert_image_dtype(inpt, dtype=self.dtype)
         return (
-            output
-            if features.is_simple_tensor(inpt)
-            else features.Image.wrap_like(inpt, output)  # type: ignore[arg-type]
+            output if features.is_simple_tensor(inpt) else type(inpt).wrap_like(inpt, output)  # type: ignore[attr-defined]
         )
...
@@ -140,6 +140,7 @@ class GaussianBlur(Transform):
         return F.gaussian_blur(inpt, self.kernel_size, **params)
 
 
+# TODO: Enhance as described at https://github.com/pytorch/vision/issues/6697
 class ToDtype(Lambda):
     def __init__(self, dtype: torch.dtype, *types: Type) -> None:
         self.dtype = dtype
...
@@ -96,6 +96,7 @@ from ._geometry import (
     five_crop,
     five_crop_image_pil,
     five_crop_image_tensor,
+    five_crop_video,
     hflip,  # TODO: Consider moving all pure alias definitions at the bottom of the file
     horizontal_flip,
     horizontal_flip_bounding_box,
@@ -136,6 +137,7 @@ from ._geometry import (
     ten_crop,
     ten_crop_image_pil,
     ten_crop_image_tensor,
+    ten_crop_video,
     vertical_flip,
     vertical_flip_bounding_box,
     vertical_flip_image_pil,
...
@@ -35,7 +35,7 @@ def erase(
     if isinstance(inpt, torch.Tensor):
         output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
         if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
-            output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]
+            output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]
         return output
     else:  # isinstance(inpt, PIL.Image.Image):
         return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
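As a quick illustration of the dispatcher pattern in the `erase` hunk above: plain tensors stay plain, while Image and Video inputs are re-wrapped (outside of scripting). A sketch assuming the prototype functional import path:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

video = features.Video(torch.rand(4, 3, 32, 32))  # assumed [T, C, H, W] layout
patch = torch.zeros(3, 8, 8)

erased = F.erase(video, i=0, j=0, h=8, w=8, v=patch)
print(type(erased).__name__)  # expected: Video (wrap_like preserves the feature type)

plain = F.erase(torch.rand(4, 3, 32, 32), i=0, j=0, h=8, w=8, v=patch)
print(type(plain).__name__)   # expected: Tensor (simple tensors are not wrapped)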
 import warnings
-from typing import Any, List
+from typing import Any, List, Union
 
 import PIL.Image
 import torch
@@ -22,10 +22,13 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
     return _F.to_grayscale(inpt, num_output_channels=num_output_channels)
 
 
-def rgb_to_grayscale(inpt: features.LegacyImageTypeJIT, num_output_channels: int = 1) -> features.LegacyImageTypeJIT:
+def rgb_to_grayscale(
+    inpt: Union[features.LegacyImageTypeJIT, features.LegacyVideoTypeJIT], num_output_channels: int = 1
+) -> Union[features.LegacyImageTypeJIT, features.LegacyVideoTypeJIT]:
     old_color_space = (
         features._image._from_tensor_shape(inpt.shape)  # type: ignore[arg-type]
-        if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image))
+        if isinstance(inpt, torch.Tensor)
+        and (torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video)))
         else None
     )
@@ -56,7 +59,7 @@ def to_tensor(inpt: Any) -> torch.Tensor:
     return _F.to_tensor(inpt)
 
 
-def get_image_size(inpt: features.ImageTypeJIT) -> List[int]:
+def get_image_size(inpt: features.ImageOrVideoTypeJIT) -> List[int]:
     warnings.warn(
         "The function `get_image_size(...)` is deprecated and will be removed in a future release. "
         "Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`."
...
@@ -1376,16 +1376,27 @@ def five_crop_image_pil(
     return tl, tr, bl, br, center
 
 
+def five_crop_video(
+    video: torch.Tensor, size: List[int]
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    return five_crop_image_tensor(video, size)
+
+
 def five_crop(
-    inpt: features.ImageTypeJIT, size: List[int]
+    inpt: features.ImageOrVideoTypeJIT, size: List[int]
 ) -> Tuple[
-    features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT
+    features.ImageOrVideoTypeJIT,
+    features.ImageOrVideoTypeJIT,
+    features.ImageOrVideoTypeJIT,
+    features.ImageOrVideoTypeJIT,
+    features.ImageOrVideoTypeJIT,
 ]:
-    # TODO: consider breaking BC here to return List[features.ImageTypeJIT] to align this op with `ten_crop`
+    # TODO: consider breaking BC here to return List[features.ImageOrVideoTypeJIT] to align this op with `ten_crop`
     if isinstance(inpt, torch.Tensor):
         output = five_crop_image_tensor(inpt, size)
-        if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
-            output = tuple(features.Image.wrap_like(inpt, item) for item in output)  # type: ignore[assignment]
+        if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
+            tmp = tuple(inpt.wrap_like(inpt, item) for item in output)  # type: ignore[arg-type]
+            output = tmp  # type: ignore[assignment]
         return output
     else:  # isinstance(inpt, PIL.Image.Image):
         return five_crop_image_pil(inpt, size)
@@ -1418,11 +1429,17 @@ def ten_crop_image_pil(image: PIL.Image.Image, size: List[int], vertical_flip: b
     return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
 
 
-def ten_crop(inpt: features.ImageTypeJIT, size: List[int], vertical_flip: bool = False) -> List[features.ImageTypeJIT]:
+def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
+    return ten_crop_image_tensor(video, size, vertical_flip=vertical_flip)
+
+
+def ten_crop(
+    inpt: features.ImageOrVideoTypeJIT, size: List[int], vertical_flip: bool = False
+) -> List[features.ImageOrVideoTypeJIT]:
     if isinstance(inpt, torch.Tensor):
         output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
-        if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
-            output = [features.Image.wrap_like(inpt, item) for item in output]
+        if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
+            output = [inpt.wrap_like(inpt, item) for item in output]  # type: ignore[arg-type]
         return output
     else:  # isinstance(inpt, PIL.Image.Image):
         return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
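The new `five_crop_video` and `ten_crop_video` kernels are thin aliases of the image-tensor kernels, and the `five_crop`/`ten_crop` dispatchers now re-wrap outputs for Video inputs. A usage sketch (import path and [T, C, H, W] layout assumed):

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

frames = torch.rand(8, 3, 256, 256)  # plain tensor video, assumed [T, C, H, W]

tl, tr, bl, br, center = F.five_crop_video(frames, [224, 224])
print(center.shape)  # torch.Size([8, 3, 224, 224])

crops = F.ten_crop(features.Video(frames), [224, 224], vertical_flip=True)
print(len(crops), type(crops[0]).__name__)  # expected: 10 Video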
@@ -55,6 +55,10 @@ def get_spatial_size_image_pil(image: PIL.Image.Image) -> List[int]:
     return [height, width]
 
 
+# TODO: Should we have get_spatial_size_video here? How about masks/bbox etc? What is the criterion for deciding when
+# a kernel will be created?
 def get_spatial_size(inpt: features.InputTypeJIT) -> List[int]:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return get_spatial_size_image_tensor(inpt)
@@ -246,7 +250,7 @@ def convert_color_space(
     ):
         if old_color_space is None:
             raise RuntimeError(
-                "In order to convert the color space of simple tensor images, "
+                "In order to convert the color space of simple tensors, "
                 "the `old_color_space=...` parameter needs to be passed."
            )
         return convert_color_space_image_tensor(
...