Unverified Commit 54a2d4e8 authored by Vasilis Vryniotis, committed by GitHub

[prototype] Video Classes Clean up (#6751)

* Removing unnecessary methods/classes.

* Unions instead of ImageOrVideo types

* Fixing JIT issue.
parent 7d36d263
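
Most of the diff below applies a single pattern: the shared ImageOrVideoType-style aliases are removed from features, and every affected signature spells out the union explicitly instead (resolving, in favor of inline unions, the TODO that is deleted from _video.py further down). A minimal sketch of the pattern, using hypothetical stand-in types rather than the real torchvision.prototype.features aliases:

from typing import Any, Dict, Union

import PIL.Image
import torch

# Hypothetical stand-ins; the real features.ImageType / features.VideoType
# aliases are broader (they also admit the Image / Video tensor subclasses).
ImageType = Union[torch.Tensor, PIL.Image.Image]
VideoType = torch.Tensor


def _transform(inpt: Union[ImageType, VideoType], params: Dict[str, Any]) -> Union[ImageType, VideoType]:
    # The body is untouched; only the annotation moves from the removed
    # combined alias to an explicit Union of the per-type aliases.
    return inpt
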
from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._encoded import EncodedData, EncodedImage, EncodedVideo
from ._encoded import EncodedData, EncodedImage
from ._feature import _Feature, FillType, FillTypeJIT, InputType, InputTypeJIT, is_simple_tensor
from ._image import (
ColorSpace,
@@ -14,12 +14,10 @@ from ._image import (
from ._label import Label, OneHotLabel
from ._mask import Mask
from ._video import (
ImageOrVideoType,
ImageOrVideoTypeJIT,
LegacyVideoType,
LegacyVideoTypeJIT,
TensorImageOrVideoType,
TensorImageOrVideoTypeJIT,
TensorVideoType,
TensorVideoTypeJIT,
Video,
VideoType,
VideoTypeJIT,
......
@@ -55,7 +55,3 @@ class EncodedImage(EncodedData):
self._spatial_size = image.height, image.width
return self._spatial_size
class EncodedVideo(EncodedData):
pass
@@ -6,10 +6,8 @@ from typing import Any, cast, List, Optional, Tuple, Union
import PIL.Image
import torch
from torchvision._utils import StrEnum
from torchvision.transforms.functional import InterpolationMode, to_pil_image
from torchvision.utils import draw_bounding_boxes, make_grid
from torchvision.transforms.functional import InterpolationMode
from ._bounding_box import BoundingBox
from ._feature import _Feature, FillTypeJIT
@@ -124,16 +122,6 @@ class Image(_Feature):
color_space=color_space,
)
def show(self) -> None:
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state
to_pil_image(make_grid(self.view(-1, *self.shape[-3:]))).show()
def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> Image:
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state
return Image.wrap_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs))
def horizontal_flip(self) -> Image:
output = self._F.horizontal_flip_image_tensor(self)
return Image.wrap_like(self, output)
......
@@ -7,7 +7,7 @@ import torch
from torchvision.transforms.functional import InterpolationMode
from ._feature import _Feature, FillTypeJIT
from ._image import ColorSpace, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
from ._image import ColorSpace
class Video(_Feature):
@@ -236,9 +236,3 @@ LegacyVideoType = torch.Tensor
LegacyVideoTypeJIT = torch.Tensor
TensorVideoType = Union[torch.Tensor, Video]
TensorVideoTypeJIT = torch.Tensor
# TODO: decide if we should do definitions for both Images and Videos or use unions in the methods
ImageOrVideoType = Union[ImageType, VideoType]
ImageOrVideoTypeJIT = Union[ImageTypeJIT, VideoTypeJIT]
TensorImageOrVideoType = Union[TensorImageType, TensorVideoType]
TensorImageOrVideoTypeJIT = Union[TensorImageTypeJIT, TensorVideoTypeJIT]
import math
import numbers
import warnings
from typing import Any, cast, Dict, List, Optional, Tuple
from typing import Any, cast, Dict, List, Optional, Tuple, Union
import PIL.Image
import torch
@@ -92,14 +92,15 @@ class RandomErasing(_RandomApplyTransform):
return dict(i=i, j=j, h=h, w=w, v=v)
def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
def _transform(
self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
) -> Union[features.ImageType, features.VideoType]:
if params["v"] is not None:
inpt = F.erase(inpt, **params, inplace=self.inplace)
return inpt
# TODO: Add support for Video: https://github.com/pytorch/vision/issues/6731
class _BaseMixupCutmix(_RandomApplyTransform):
def __init__(self, alpha: float, p: float = 0.5) -> None:
super().__init__(p=p)
......
@@ -35,7 +35,7 @@ class _AutoAugmentBase(Transform):
self,
sample: Any,
unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.Mask),
) -> Tuple[int, features.ImageOrVideoType]:
) -> Tuple[int, Union[features.ImageType, features.VideoType]]:
sample_flat, _ = tree_flatten(sample)
image_or_videos = []
for id, inpt in enumerate(sample_flat):
@@ -60,12 +60,12 @@ class _AutoAugmentBase(Transform):
def _apply_image_or_video_transform(
self,
image: features.ImageOrVideoType,
image: Union[features.ImageType, features.VideoType],
transform_id: str,
magnitude: float,
interpolation: InterpolationMode,
fill: Dict[Type, features.FillType],
) -> features.ImageOrVideoType:
) -> Union[features.ImageType, features.VideoType]:
fill_ = fill[type(image)]
fill_ = F._geometry._convert_fill_arg(fill_)
......
@@ -111,8 +111,8 @@ class RandomPhotometricDistort(Transform):
)
def _permute_channels(
self, inpt: features.ImageOrVideoType, permutation: torch.Tensor
) -> features.ImageOrVideoType:
self, inpt: Union[features.ImageType, features.VideoType], permutation: torch.Tensor
) -> Union[features.ImageType, features.VideoType]:
if isinstance(inpt, PIL.Image.Image):
inpt = F.pil_to_tensor(inpt)
@@ -126,7 +126,9 @@ class RandomPhotometricDistort(Transform):
return output
def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
def _transform(
self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
) -> Union[features.ImageType, features.VideoType]:
if params["brightness"]:
inpt = F.adjust_brightness(
inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1])
......
@@ -52,7 +52,9 @@ class Grayscale(Transform):
super().__init__()
self.num_output_channels = num_output_channels
def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
def _transform(
self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
) -> Union[features.ImageType, features.VideoType]:
output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
if isinstance(inpt, (features.Image, features.Video)):
output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY) # type: ignore[arg-type]
@@ -81,7 +83,9 @@ class RandomGrayscale(_RandomApplyTransform):
num_input_channels, *_ = query_chw(sample)
return dict(num_input_channels=num_input_channels)
def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
def _transform(
self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
) -> Union[features.ImageType, features.VideoType]:
output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
if isinstance(inpt, (features.Image, features.Video)):
output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY) # type: ignore[arg-type]
......
@@ -148,6 +148,9 @@ class RandomResizedCrop(Transform):
)
ImageOrVideoTypeJIT = Union[features.ImageTypeJIT, features.VideoTypeJIT]
class FiveCrop(Transform):
"""
Example:
@@ -177,14 +180,8 @@ class FiveCrop(Transform):
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
def _transform(
self, inpt: features.ImageOrVideoType, params: Dict[str, Any]
) -> Tuple[
features.ImageOrVideoType,
features.ImageOrVideoType,
features.ImageOrVideoType,
features.ImageOrVideoType,
features.ImageOrVideoType,
]:
self, inpt: ImageOrVideoTypeJIT, params: Dict[str, Any]
) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
return F.five_crop(inpt, self.size)
def forward(self, *inputs: Any) -> Any:
@@ -205,7 +202,9 @@ class TenCrop(Transform):
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
self.vertical_flip = vertical_flip
def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> List[features.ImageOrVideoType]:
def _transform(
self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
) -> Union[List[features.ImageTypeJIT], List[features.VideoTypeJIT]]:
return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
def forward(self, *inputs: Any) -> Any:
......
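
In the FiveCrop hunk above, the removed features.ImageOrVideoTypeJIT is replaced by a module-local ImageOrVideoTypeJIT alias so the five-element Tuple return annotation stays readable; the functional five_crop further down uses the same trick. A minimal sketch of the pattern, with hypothetical stand-ins for the JIT-facing type aliases (the real ones are not reproduced here):

from typing import List, Tuple, Union

import torch

# Hypothetical stand-ins for features.ImageTypeJIT / features.VideoTypeJIT.
ImageTypeJIT = torch.Tensor
VideoTypeJIT = torch.Tensor

# Define the union once per module instead of importing a shared alias.
ImageOrVideoTypeJIT = Union[ImageTypeJIT, VideoTypeJIT]


def five_crop(
    inpt: ImageOrVideoTypeJIT, size: List[int]
) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
    raise NotImplementedError  # placeholder; the real kernel dispatch is shown in the diff below
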
@@ -29,8 +29,8 @@ class ConvertImageDtype(Transform):
self.dtype = dtype
def _transform(
self, inpt: features.TensorImageOrVideoType, params: Dict[str, Any]
) -> features.TensorImageOrVideoType:
self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
) -> Union[features.TensorImageType, features.TensorVideoType]:
output = F.convert_image_dtype(inpt, dtype=self.dtype)
return (
output if features.is_simple_tensor(inpt) else type(inpt).wrap_like(inpt, output) # type: ignore[attr-defined]
@@ -58,7 +58,9 @@ class ConvertColorSpace(Transform):
self.copy = copy
def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
def _transform(
self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
) -> Union[features.ImageType, features.VideoType]:
return F.convert_color_space(
inpt, color_space=self.color_space, old_color_space=self.old_color_space, copy=self.copy
)
......
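
ConvertImageDtype._transform above preserves the input type by checking is_simple_tensor and re-wrapping the result with wrap_like; the functional kernels below use a similar re-wrapping dispatch. A rough sketch of that idiom, using a hypothetical FakeImage subclass rather than the real features classes:

import torch


class FakeImage(torch.Tensor):
    # Hypothetical stand-in for a feature subclass such as features.Image.

    @classmethod
    def wrap_like(cls, other, tensor):
        # Re-wrap a plain tensor result as the same subclass as `other`; the
        # real feature classes also carry metadata (e.g. color_space) across.
        return tensor.as_subclass(cls)


def is_simple_tensor(inpt):
    # True for plain tensors that are not a feature subclass.
    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, FakeImage)


def keep_input_type(inpt, output):
    # Plain tensor in -> plain tensor out; feature in -> same feature type out.
    return output if is_simple_tensor(inpt) else type(inpt).wrap_like(inpt, output)
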
@@ -68,7 +68,9 @@ class LinearTransformation(Transform):
return super().forward(*inputs)
def _transform(self, inpt: features.TensorImageOrVideoType, params: Dict[str, Any]) -> torch.Tensor:
def _transform(
self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
) -> torch.Tensor:
# Image instance after linear transformation is not Image anymore due to unknown data range
# Thus we will return Tensor for input Image
@@ -101,7 +103,9 @@ class Normalize(Transform):
self.std = list(std)
self.inplace = inplace
def _transform(self, inpt: features.TensorImageOrVideoType, params: Dict[str, Any]) -> torch.Tensor:
def _transform(
self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
) -> torch.Tensor:
return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
def forward(self, *inpts: Any) -> Any:
......
from typing import Union
import PIL.Image
import torch
@@ -24,14 +26,14 @@ def erase_video(
def erase(
inpt: features.ImageOrVideoTypeJIT,
inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT],
i: int,
j: int,
h: int,
w: int,
v: torch.Tensor,
inplace: bool = False,
) -> features.ImageOrVideoTypeJIT:
) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]:
if isinstance(inpt, torch.Tensor):
output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
......
@@ -59,7 +59,7 @@ def to_tensor(inpt: Any) -> torch.Tensor:
return _F.to_tensor(inpt)
def get_image_size(inpt: features.ImageOrVideoTypeJIT) -> List[int]:
def get_image_size(inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> List[int]:
warnings.warn(
"The function `get_image_size(...)` is deprecated and will be removed in a future release. "
"Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`."
......
@@ -1382,16 +1382,13 @@ def five_crop_video(
return five_crop_image_tensor(video, size)
ImageOrVideoTypeJIT = Union[features.ImageTypeJIT, features.VideoTypeJIT]
def five_crop(
inpt: features.ImageOrVideoTypeJIT, size: List[int]
) -> Tuple[
features.ImageOrVideoTypeJIT,
features.ImageOrVideoTypeJIT,
features.ImageOrVideoTypeJIT,
features.ImageOrVideoTypeJIT,
features.ImageOrVideoTypeJIT,
]:
# TODO: consider breaking BC here to return List[features.ImageOrVideoTypeJIT] to align this op with `ten_crop`
inpt: ImageOrVideoTypeJIT, size: List[int]
) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
# TODO: consider breaking BC here to return List[features.ImageTypeJIT/VideoTypeJIT] to align this op with `ten_crop`
if isinstance(inpt, torch.Tensor):
output = five_crop_image_tensor(inpt, size)
if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
@@ -1434,8 +1431,8 @@ def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = F
def ten_crop(
inpt: features.ImageOrVideoTypeJIT, size: List[int], vertical_flip: bool = False
) -> List[features.ImageOrVideoTypeJIT]:
inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], size: List[int], vertical_flip: bool = False
) -> Union[List[features.ImageTypeJIT], List[features.VideoTypeJIT]]:
if isinstance(inpt, torch.Tensor):
output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
......
from typing import cast, List, Optional, Tuple
from typing import List, Optional, Tuple, Union
import PIL.Image
import torch
@@ -11,7 +11,7 @@ get_dimensions_image_tensor = _FT.get_dimensions
get_dimensions_image_pil = _FP.get_dimensions
def get_dimensions(image: features.ImageOrVideoTypeJIT) -> List[int]:
def get_dimensions(image: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> List[int]:
if isinstance(image, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
):
@@ -32,7 +32,7 @@ def get_num_channels_video(video: torch.Tensor) -> int:
return get_num_channels_image_tensor(video)
def get_num_channels(image: features.ImageOrVideoTypeJIT) -> int:
def get_num_channels(image: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> int:
if isinstance(image, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
):
@@ -262,11 +262,11 @@ def convert_color_space_video(
def convert_color_space(
inpt: features.ImageOrVideoTypeJIT,
inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT],
color_space: ColorSpace,
old_color_space: Optional[ColorSpace] = None,
copy: bool = True,
) -> features.ImageOrVideoTypeJIT:
) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
):
@@ -281,4 +281,4 @@ def convert_color_space(
elif isinstance(inpt, (features.Image, features.Video)):
return inpt.to_color_space(color_space, copy=copy)
else:
return cast(features.ImageOrVideoTypeJIT, convert_color_space_image_pil(inpt, color_space, copy=copy))
return convert_color_space_image_pil(inpt, color_space, copy=copy)
from typing import List, Optional
from typing import List, Optional, Union
import PIL.Image
import torch
@@ -14,7 +14,10 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], in
def normalize(
inpt: features.TensorImageOrVideoTypeJIT, mean: List[float], std: List[float], inplace: bool = False
inpt: Union[features.TensorImageTypeJIT, features.TensorVideoTypeJIT],
mean: List[float],
std: List[float],
inplace: bool = False,
) -> torch.Tensor:
if torch.jit.is_scripting():
correct_type = isinstance(inpt, torch.Tensor)
......