Unverified Commit 3118fb52 authored by Philip Meier, committed by GitHub

add Video feature and kernels (#6667)

* add video feature

* add video kernels

* add video testing utils

* add one kernel info

* fix kernel names in Video feature

* use only uint8 for video testing

* require at least 4 dims for Video feature

* add TODO for image_size -> spatial_size

* image -> video in feature constructor

* introduce new combined images and video type

* add video to transform utils

* fix transforms test

* fix auto augment

* cleanup

* address review comments

* add remaining video kernel infos

* add batch dimension squashing to some kernels

* fix tests and kernel infos

* add xfails for arbitrary batch sizes on some kernels

* fix test setup

* fix equalize_image_tensor for multi batch dims

* fix adjust_sharpness_image_tensor for multi batch dims

* address review comments
parent 7eb5d7fc
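
For orientation, a minimal usage sketch of the new Video feature and its kernels, assuming the prototype namespace shown in this diff (the tensor shape and the resize target are illustrative, not part of the commit):

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

# A Video is a tensor feature with at least 4 dimensions: (..., T, C, H, W).
frames = torch.randint(0, 256, (8, 3, 32, 32), dtype=torch.uint8)
video = features.Video(frames, color_space=features.ColorSpace.RGB)

video.num_frames    # 8
video.num_channels  # 3
video.image_size    # (32, 32)

# Dispatchers accept Video features and route them to the *_video kernels
# via the methods defined on the Video class.
flipped = F.horizontal_flip(video)   # features.Video, via horizontal_flip_video
resized = F.resize(video, [16, 16])  # features.Video, via resize_video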
......@@ -45,6 +45,8 @@ __all__ = [
"make_segmentation_masks",
"make_mask_loaders",
"make_masks",
"make_video",
"make_videos",
]
......@@ -210,17 +212,19 @@ DEFAULT_EXTRA_DIMS = (*VALID_EXTRA_DIMS, *DEGENERATE_BATCH_DIMS)
def from_loader(loader_fn):
def wrapper(*args, **kwargs):
device = kwargs.pop("device", "cpu")
loader = loader_fn(*args, **kwargs)
return loader.load(kwargs.get("device", "cpu"))
return loader.load(device)
return wrapper
def from_loaders(loaders_fn):
def wrapper(*args, **kwargs):
device = kwargs.pop("device", "cpu")
loaders = loaders_fn(*args, **kwargs)
for loader in loaders:
yield loader.load(kwargs.get("device", "cpu"))
yield loader.load(device)
return wrapper
......@@ -246,6 +250,21 @@ class ImageLoader(TensorLoader):
self.num_channels = self.shape[-3]
NUM_CHANNELS_MAP = {
features.ColorSpace.GRAY: 1,
features.ColorSpace.GRAY_ALPHA: 2,
features.ColorSpace.RGB: 3,
features.ColorSpace.RGB_ALPHA: 4,
}
def get_num_channels(color_space):
num_channels = NUM_CHANNELS_MAP.get(color_space)
if not num_channels:
raise pytest.UsageError(f"Can't determine the number of channels for color space {color_space}")
return num_channels
def make_image_loader(
size="random",
*,
......@@ -255,16 +274,7 @@ def make_image_loader(
constant_alpha=True,
):
size = _parse_image_size(size)
try:
num_channels = {
features.ColorSpace.GRAY: 1,
features.ColorSpace.GRAY_ALPHA: 2,
features.ColorSpace.RGB: 3,
features.ColorSpace.RGB_ALPHA: 4,
}[color_space]
except KeyError as error:
raise pytest.UsageError(f"Can't determine the number of channels for color space {color_space}") from error
num_channels = get_num_channels(color_space)
def fn(shape, dtype, device):
max_value = get_max_value(dtype)
......@@ -531,3 +541,50 @@ def make_mask_loaders(
make_masks = from_loaders(make_mask_loaders)
class VideoLoader(ImageLoader):
pass
def make_video_loader(
size="random",
*,
color_space=features.ColorSpace.RGB,
num_frames="random",
extra_dims=(),
dtype=torch.uint8,
):
size = _parse_image_size(size)
num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames
def fn(shape, dtype, device):
video = make_image(size=shape[-2:], color_space=color_space, extra_dims=shape[:-3], dtype=dtype, device=device)
return features.Video(video, color_space=color_space)
return VideoLoader(
fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype, color_space=color_space
)
make_video = from_loader(make_video_loader)
def make_video_loaders(
*,
sizes=DEFAULT_IMAGE_SIZES,
color_spaces=(
features.ColorSpace.GRAY,
features.ColorSpace.RGB,
),
num_frames=(1, 0, "random"),
extra_dims=DEFAULT_EXTRA_DIMS,
dtypes=(torch.uint8,),
):
for params in combinations_grid(
size=sizes, color_space=color_spaces, num_frames=num_frames, extra_dims=extra_dims, dtype=dtypes
):
yield make_video_loader(**params)
make_videos = from_loaders(make_video_loaders)
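
A rough sketch of how these new test helpers might be consumed in the test suite (the concrete sizes, frame counts, and dtypes below are illustrative):

import torch
from torchvision.prototype import features
from prototype_common_utils import make_video, make_videos

# Single features.Video with the default random number of frames.
video = make_video(size=(32, 32), color_space=features.ColorSpace.RGB)
assert isinstance(video, features.Video) and video.ndim >= 4  # (*extra_dims, T, C, H, W)

# Iterate over the full parameter grid built by make_video_loaders.
for video in make_videos(sizes=[(16, 16)], num_frames=[1, "random"], dtypes=[torch.uint8]):
    assert video.dtype == torch.uint8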
......@@ -127,6 +127,23 @@ xfail_dispatch_pil_if_fill_sequence_needs_broadcast = TestMark(
)
def xfail_all_tests(*, reason, condition):
return [
TestMark(("TestDispatchers", test_name), pytest.mark.xfail(reason=reason), condition=condition)
for test_name in [
"test_scripted_smoke",
"test_dispatch_simple_tensor",
"test_dispatch_feature",
]
]
xfails_degenerate_or_multi_batch_dims = xfail_all_tests(
reason="See https://github.com/pytorch/vision/issues/6670 for details.",
condition=lambda args_kwargs: len(args_kwargs.args[0].shape) > 4 or not all(args_kwargs.args[0].shape[:-3]),
)
DISPATCHER_INFOS = [
DispatcherInfo(
F.horizontal_flip,
......@@ -243,6 +260,7 @@ DISPATCHER_INFOS = [
pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
test_marks=[
xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
*xfails_degenerate_or_multi_batch_dims,
],
),
DispatcherInfo(
......@@ -253,6 +271,7 @@ DISPATCHER_INFOS = [
features.Mask: F.elastic_mask,
},
pil_kernel_info=PILKernelInfo(F.elastic_image_pil),
test_marks=xfails_degenerate_or_multi_batch_dims,
),
DispatcherInfo(
F.center_crop,
......@@ -275,6 +294,7 @@ DISPATCHER_INFOS = [
test_marks=[
xfail_jit_python_scalar_arg("kernel_size"),
xfail_jit_python_scalar_arg("sigma"),
*xfails_degenerate_or_multi_batch_dims,
],
),
DispatcherInfo(
......
This diff is collapsed.
......@@ -17,6 +17,7 @@ from prototype_common_utils import (
make_masks,
make_one_hot_labels,
make_segmentation_mask,
make_videos,
)
from torchvision.ops.boxes import box_iou
from torchvision.prototype import features, transforms
......@@ -65,6 +66,7 @@ def parametrize_from_transforms(*transforms):
make_vanilla_tensor_images,
make_pil_images,
make_masks,
make_videos,
]:
inputs = list(creation_fn())
try:
......@@ -155,12 +157,14 @@ class TestSmoke:
features.ColorSpace.RGB,
],
dtypes=[torch.uint8],
extra_dims=[(4,)],
extra_dims=[(), (4,)],
**(dict(num_frames=["random"]) if fn is make_videos else dict()),
)
for fn in [
make_images,
make_vanilla_tensor_images,
make_pil_images,
make_videos,
]
),
)
......@@ -184,6 +188,7 @@ class TestSmoke:
for fn in [
make_images,
make_vanilla_tensor_images,
make_videos,
]
),
),
......@@ -200,6 +205,7 @@ class TestSmoke:
make_images(extra_dims=[(4,)]),
make_vanilla_tensor_images(),
make_pil_images(),
make_videos(extra_dims=[()]),
),
)
]
......@@ -218,6 +224,7 @@ class TestSmoke:
make_images,
make_vanilla_tensor_images,
make_pil_images,
make_videos,
)
]
),
......
......@@ -129,6 +129,7 @@ class TestKernels:
# type all kernels should also work without differentiating between the two. Thus, we go with 2 here as
# common ground.
features.Mask: 2,
features.Video: 4,
}.get(feature_type)
if data_dims is None:
raise pytest.UsageError(
......
......@@ -13,3 +13,4 @@ from ._image import (
)
from ._label import Label, OneHotLabel
from ._mask import Mask
from ._video import ImageOrVideoType, ImageOrVideoTypeJIT, TensorImageOrVideoType, TensorImageOrVideoTypeJIT, Video
from __future__ import annotations
import warnings
from typing import Any, cast, List, Optional, Tuple, Union
import torch
from torchvision.transforms.functional import InterpolationMode
from ._feature import _Feature, FillTypeJIT
from ._image import ColorSpace, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
class Video(_Feature):
color_space: ColorSpace
def __new__(
cls,
data: Any,
*,
color_space: Optional[Union[ColorSpace, str]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
) -> Video:
data = torch.as_tensor(data, dtype=dtype, device=device)
if data.ndim < 4:
raise ValueError
video = super().__new__(cls, data, requires_grad=requires_grad)
if color_space is None:
color_space = ColorSpace.from_tensor_shape(video.shape) # type: ignore[arg-type]
if color_space == ColorSpace.OTHER:
warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
elif isinstance(color_space, str):
color_space = ColorSpace.from_str(color_space.upper())
elif not isinstance(color_space, ColorSpace):
raise ValueError
video.color_space = color_space
return video
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
return self._make_repr(color_space=self.color_space)
@classmethod
def new_like(
cls, other: Video, data: Any, *, color_space: Optional[Union[ColorSpace, str]] = None, **kwargs: Any
) -> Video:
return super().new_like(
other, data, color_space=color_space if color_space is not None else other.color_space, **kwargs
)
# TODO: rename this (and all instances of this term) to spatial_size
@property
def image_size(self) -> Tuple[int, int]:
return cast(Tuple[int, int], tuple(self.shape[-2:]))
@property
def num_channels(self) -> int:
return self.shape[-3]
@property
def num_frames(self) -> int:
return self.shape[-4]
def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) -> Video:
if isinstance(color_space, str):
color_space = ColorSpace.from_str(color_space.upper())
return Video.new_like(
self,
self._F.convert_color_space_video(
self, old_color_space=self.color_space, new_color_space=color_space, copy=copy
),
color_space=color_space,
)
def horizontal_flip(self) -> Video:
output = self._F.horizontal_flip_video(self)
return Video.new_like(self, output)
def vertical_flip(self) -> Video:
output = self._F.vertical_flip_video(self)
return Video.new_like(self, output)
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: bool = False,
) -> Video:
output = self._F.resize_video(self, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
return Video.new_like(self, output)
def crop(self, top: int, left: int, height: int, width: int) -> Video:
output = self._F.crop_video(self, top, left, height, width)
return Video.new_like(self, output)
def center_crop(self, output_size: List[int]) -> Video:
output = self._F.center_crop_video(self, output_size=output_size)
return Video.new_like(self, output)
def resized_crop(
self,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: bool = False,
) -> Video:
output = self._F.resized_crop_video(
self, top, left, height, width, size=list(size), interpolation=interpolation, antialias=antialias
)
return Video.new_like(self, output)
def pad(
self,
padding: Union[int, List[int]],
fill: FillTypeJIT = None,
padding_mode: str = "constant",
) -> Video:
output = self._F.pad_video(self, padding, fill=fill, padding_mode=padding_mode)
return Video.new_like(self, output)
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Video:
output = self._F._geometry.rotate_video(
self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center
)
return Video.new_like(self, output)
def affine(
self,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Video:
output = self._F._geometry.affine_video(
self,
angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
return Video.new_like(self, output)
def perspective(
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> Video:
output = self._F._geometry.perspective_video(self, perspective_coeffs, interpolation=interpolation, fill=fill)
return Video.new_like(self, output)
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> Video:
output = self._F._geometry.elastic_video(self, displacement, interpolation=interpolation, fill=fill)
return Video.new_like(self, output)
def adjust_brightness(self, brightness_factor: float) -> Video:
output = self._F.adjust_brightness_video(self, brightness_factor=brightness_factor)
return Video.new_like(self, output)
def adjust_saturation(self, saturation_factor: float) -> Video:
output = self._F.adjust_saturation_video(self, saturation_factor=saturation_factor)
return Video.new_like(self, output)
def adjust_contrast(self, contrast_factor: float) -> Video:
output = self._F.adjust_contrast_video(self, contrast_factor=contrast_factor)
return Video.new_like(self, output)
def adjust_sharpness(self, sharpness_factor: float) -> Video:
output = self._F.adjust_sharpness_video(self, sharpness_factor=sharpness_factor)
return Video.new_like(self, output)
def adjust_hue(self, hue_factor: float) -> Video:
output = self._F.adjust_hue_video(self, hue_factor=hue_factor)
return Video.new_like(self, output)
def adjust_gamma(self, gamma: float, gain: float = 1) -> Video:
output = self._F.adjust_gamma_video(self, gamma=gamma, gain=gain)
return Video.new_like(self, output)
def posterize(self, bits: int) -> Video:
output = self._F.posterize_video(self, bits=bits)
return Video.new_like(self, output)
def solarize(self, threshold: float) -> Video:
output = self._F.solarize_video(self, threshold=threshold)
return Video.new_like(self, output)
def autocontrast(self) -> Video:
output = self._F.autocontrast_video(self)
return Video.new_like(self, output)
def equalize(self) -> Video:
output = self._F.equalize_video(self)
return Video.new_like(self, output)
def invert(self) -> Video:
output = self._F.invert_video(self)
return Video.new_like(self, output)
def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Video:
output = self._F.gaussian_blur_video(self, kernel_size=kernel_size, sigma=sigma)
return Video.new_like(self, output)
VideoType = Union[torch.Tensor, Video]
VideoTypeJIT = torch.Tensor
LegacyVideoType = torch.Tensor
LegacyVideoTypeJIT = torch.Tensor
TensorVideoType = Union[torch.Tensor, Video]
TensorVideoTypeJIT = torch.Tensor
ImageOrVideoType = Union[ImageType, VideoType]
ImageOrVideoTypeJIT = Union[ImageTypeJIT, VideoTypeJIT]
TensorImageOrVideoType = Union[TensorImageType, TensorVideoType]
TensorImageOrVideoTypeJIT = Union[TensorImageTypeJIT, TensorVideoTypeJIT]
......@@ -15,7 +15,7 @@ from ._utils import has_any, query_chw
class RandomErasing(_RandomApplyTransform):
_transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image)
_transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image, features.Video)
def __init__(
self,
......@@ -92,7 +92,7 @@ class RandomErasing(_RandomApplyTransform):
return dict(i=i, j=j, h=h, w=w, v=v)
def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
if params["v"] is not None:
inpt = F.erase(inpt, **params, inplace=self.inplace)
......
......@@ -31,40 +31,41 @@ class _AutoAugmentBase(Transform):
key = keys[int(torch.randint(len(keys), ()))]
return key, dct[key]
def _extract_image(
def _extract_image_or_video(
self,
sample: Any,
unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.Mask),
) -> Tuple[int, features.ImageType]:
) -> Tuple[int, features.ImageOrVideoType]:
sample_flat, _ = tree_flatten(sample)
images = []
image_or_videos = []
for id, inpt in enumerate(sample_flat):
if _isinstance(inpt, (features.Image, PIL.Image.Image, features.is_simple_tensor)):
images.append((id, inpt))
if _isinstance(inpt, (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)):
image_or_videos.append((id, inpt))
elif isinstance(inpt, unsupported_types):
raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
if not images:
if not image_or_videos:
raise TypeError("Found no image in the sample.")
if len(images) > 1:
if len(image_or_videos) > 1:
raise TypeError(
f"Auto augment transformations are only properly defined for a single image, but found {len(images)}."
f"Auto augment transformations are only properly defined for a single image or video, "
f"but found {len(image_or_videos)}."
)
return images[0]
return image_or_videos[0]
def _put_into_sample(self, sample: Any, id: int, item: Any) -> Any:
sample_flat, spec = tree_flatten(sample)
sample_flat[id] = item
return tree_unflatten(sample_flat, spec)
def _apply_image_transform(
def _apply_image_or_video_transform(
self,
image: features.ImageType,
image: features.ImageOrVideoType,
transform_id: str,
magnitude: float,
interpolation: InterpolationMode,
fill: Dict[Type, features.FillType],
) -> features.ImageType:
) -> features.ImageOrVideoType:
fill_ = fill[type(image)]
fill_ = F._geometry._convert_fill_arg(fill_)
......@@ -276,8 +277,8 @@ class AutoAugment(_AutoAugmentBase):
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
id, image = self._extract_image(sample)
_, height, width = get_chw(image)
id, image_or_video = self._extract_image_or_video(sample)
_, height, width = get_chw(image_or_video)
policy = self._policies[int(torch.randint(len(self._policies), ()))]
......@@ -295,11 +296,11 @@ class AutoAugment(_AutoAugmentBase):
else:
magnitude = 0.0
image = self._apply_image_transform(
image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
image_or_video = self._apply_image_or_video_transform(
image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
return self._put_into_sample(sample, id, image)
return self._put_into_sample(sample, id, image_or_video)
class RandAugment(_AutoAugmentBase):
......@@ -347,8 +348,8 @@ class RandAugment(_AutoAugmentBase):
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
id, image = self._extract_image(sample)
_, height, width = get_chw(image)
id, image_or_video = self._extract_image_or_video(sample)
_, height, width = get_chw(image_or_video)
for _ in range(self.num_ops):
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
......@@ -359,11 +360,11 @@ class RandAugment(_AutoAugmentBase):
magnitude *= -1
else:
magnitude = 0.0
image = self._apply_image_transform(
image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
image_or_video = self._apply_image_or_video_transform(
image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
return self._put_into_sample(sample, id, image)
return self._put_into_sample(sample, id, image_or_video)
class TrivialAugmentWide(_AutoAugmentBase):
......@@ -401,8 +402,8 @@ class TrivialAugmentWide(_AutoAugmentBase):
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
id, image = self._extract_image(sample)
_, height, width = get_chw(image)
id, image_or_video = self._extract_image_or_video(sample)
_, height, width = get_chw(image_or_video)
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
......@@ -414,10 +415,10 @@ class TrivialAugmentWide(_AutoAugmentBase):
else:
magnitude = 0.0
image = self._apply_image_transform(
image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
image_or_video = self._apply_image_or_video_transform(
image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
return self._put_into_sample(sample, id, image)
return self._put_into_sample(sample, id, image_or_video)
class AugMix(_AutoAugmentBase):
......@@ -471,27 +472,28 @@ class AugMix(_AutoAugmentBase):
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
id, orig_image = self._extract_image(sample)
_, height, width = get_chw(orig_image)
id, orig_image_or_video = self._extract_image_or_video(sample)
_, height, width = get_chw(orig_image_or_video)
if isinstance(orig_image, torch.Tensor):
image = orig_image
if isinstance(orig_image_or_video, torch.Tensor):
image_or_video = orig_image_or_video
else: # isinstance(inpt, PIL.Image.Image):
image = F.pil_to_tensor(orig_image)
image_or_video = F.pil_to_tensor(orig_image_or_video)
augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
orig_dims = list(image.shape)
batch = image.view([1] * max(4 - image.ndim, 0) + orig_dims)
orig_dims = list(image_or_video.shape)
batch = image_or_video.view([1] * max(4 - image_or_video.ndim, 0) + orig_dims)
batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
# Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet
# with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of augmented image.
# Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
# Dirichlet with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of
# augmented image or video.
m = self._sample_dirichlet(
torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
)
# Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images.
# Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos.
combined_weights = self._sample_dirichlet(
torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
) * m[:, 1].view([batch_dims[0], -1])
......@@ -511,15 +513,15 @@ class AugMix(_AutoAugmentBase):
else:
magnitude = 0.0
aug = self._apply_image_transform(
aug = self._apply_image_or_video_transform(
aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
mix.add_(combined_weights[:, i].view(batch_dims) * aug)
mix = mix.view(orig_dims).to(dtype=image.dtype)
mix = mix.view(orig_dims).to(dtype=image_or_video.dtype)
if isinstance(orig_image, features.Image):
mix = features.Image.new_like(orig_image, mix)
elif isinstance(orig_image, PIL.Image.Image):
if isinstance(orig_image_or_video, (features.Image, features.Video)):
mix = type(orig_image_or_video).new_like(orig_image_or_video, mix) # type: ignore[arg-type]
elif isinstance(orig_image_or_video, PIL.Image.Image):
mix = F.to_image_pil(mix)
return self._put_into_sample(sample, id, mix)
......@@ -82,7 +82,7 @@ class ColorJitter(Transform):
class RandomPhotometricDistort(Transform):
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
_transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)
def __init__(
self,
......@@ -110,20 +110,22 @@ class RandomPhotometricDistort(Transform):
channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None,
)
def _permute_channels(self, inpt: features.ImageType, permutation: torch.Tensor) -> features.ImageType:
def _permute_channels(
self, inpt: features.ImageOrVideoType, permutation: torch.Tensor
) -> features.ImageOrVideoType:
if isinstance(inpt, PIL.Image.Image):
inpt = F.pil_to_tensor(inpt)
output = inpt[..., permutation, :, :]
if isinstance(inpt, features.Image):
output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.OTHER)
if isinstance(inpt, (features.Image, features.Video)):
output = type(inpt).new_like(inpt, output, color_space=features.ColorSpace.OTHER) # type: ignore[arg-type]
elif isinstance(inpt, PIL.Image.Image):
output = F.to_image_pil(output)
return output
def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
if params["brightness"]:
inpt = F.adjust_brightness(
inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1])
......
......@@ -855,8 +855,10 @@ class FixedSizeCrop(Transform):
return inpt
def forward(self, *inputs: Any) -> Any:
if not has_any(inputs, PIL.Image.Image, features.Image, features.is_simple_tensor):
raise TypeError(f"{type(self).__name__}() requires input sample to contain an tensor or PIL image.")
if not has_any(inputs, PIL.Image.Image, features.Image, features.is_simple_tensor, features.Video):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain an tensor or PIL image or a Video."
)
if has_any(inputs, features.BoundingBox) and not has_any(inputs, features.Label, features.OneHotLabel):
raise TypeError(
......
......@@ -34,7 +34,7 @@ class ConvertImageDtype(Transform):
class ConvertColorSpace(Transform):
_transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image)
_transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image, features.Video)
def __init__(
self,
......@@ -54,7 +54,7 @@ class ConvertColorSpace(Transform):
self.copy = copy
def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
return F.convert_color_space(
inpt, color_space=self.color_space, old_color_space=self.old_color_space, copy=self.copy
)
......
......@@ -38,7 +38,7 @@ class Lambda(Transform):
class LinearTransformation(Transform):
_transformed_types = (features.is_simple_tensor, features.Image)
_transformed_types = (features.is_simple_tensor, features.Image, features.Video)
def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
super().__init__()
......@@ -68,7 +68,7 @@ class LinearTransformation(Transform):
return super().forward(*inputs)
def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> torch.Tensor:
def _transform(self, inpt: features.TensorImageOrVideoType, params: Dict[str, Any]) -> torch.Tensor:
# Image instance after linear transformation is not Image anymore due to unknown data range
# Thus we will return Tensor for input Image
......@@ -93,7 +93,7 @@ class LinearTransformation(Transform):
class Normalize(Transform):
_transformed_types = (features.Image, features.is_simple_tensor)
_transformed_types = (features.Image, features.is_simple_tensor, features.Video)
def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
super().__init__()
......@@ -101,7 +101,7 @@ class Normalize(Transform):
self.std = list(std)
self.inplace = inplace
def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> torch.Tensor:
def _transform(self, inpt: features.TensorImageOrVideoType, params: Dict[str, Any]) -> torch.Tensor:
return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
def forward(self, *inpts: Any) -> Any:
......
......@@ -82,10 +82,10 @@ def query_chw(sample: Any) -> Tuple[int, int, int]:
chws = {
get_chw(item)
for item in flat_sample
if isinstance(item, (features.Image, PIL.Image.Image)) or features.is_simple_tensor(item)
if isinstance(item, (features.Image, PIL.Image.Image, features.Video)) or features.is_simple_tensor(item)
}
if not chws:
raise TypeError("No image was found in the sample")
raise TypeError("No image or video was found in the sample")
elif len(chws) > 1:
raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
return chws.pop()
......
......@@ -6,6 +6,7 @@ from ._meta import (
convert_format_bounding_box,
convert_color_space_image_tensor,
convert_color_space_image_pil,
convert_color_space_video,
convert_color_space,
get_dimensions,
get_image_num_channels,
......@@ -13,41 +14,52 @@ from ._meta import (
get_spatial_size,
) # usort: skip
from ._augment import erase, erase_image_pil, erase_image_tensor
from ._augment import erase, erase_image_pil, erase_image_tensor, erase_video
from ._color import (
adjust_brightness,
adjust_brightness_image_pil,
adjust_brightness_image_tensor,
adjust_brightness_video,
adjust_contrast,
adjust_contrast_image_pil,
adjust_contrast_image_tensor,
adjust_contrast_video,
adjust_gamma,
adjust_gamma_image_pil,
adjust_gamma_image_tensor,
adjust_gamma_video,
adjust_hue,
adjust_hue_image_pil,
adjust_hue_image_tensor,
adjust_hue_video,
adjust_saturation,
adjust_saturation_image_pil,
adjust_saturation_image_tensor,
adjust_saturation_video,
adjust_sharpness,
adjust_sharpness_image_pil,
adjust_sharpness_image_tensor,
adjust_sharpness_video,
autocontrast,
autocontrast_image_pil,
autocontrast_image_tensor,
autocontrast_video,
equalize,
equalize_image_pil,
equalize_image_tensor,
equalize_video,
invert,
invert_image_pil,
invert_image_tensor,
invert_video,
posterize,
posterize_image_pil,
posterize_image_tensor,
posterize_video,
solarize,
solarize_image_pil,
solarize_image_tensor,
solarize_video,
)
from ._geometry import (
affine,
......@@ -55,22 +67,26 @@ from ._geometry import (
affine_image_pil,
affine_image_tensor,
affine_mask,
affine_video,
center_crop,
center_crop_bounding_box,
center_crop_image_pil,
center_crop_image_tensor,
center_crop_mask,
center_crop_video,
crop,
crop_bounding_box,
crop_image_pil,
crop_image_tensor,
crop_mask,
crop_video,
elastic,
elastic_bounding_box,
elastic_image_pil,
elastic_image_tensor,
elastic_mask,
elastic_transform,
elastic_video,
five_crop,
five_crop_image_pil,
five_crop_image_tensor,
......@@ -80,31 +96,37 @@ from ._geometry import (
horizontal_flip_image_pil,
horizontal_flip_image_tensor,
horizontal_flip_mask,
horizontal_flip_video,
pad,
pad_bounding_box,
pad_image_pil,
pad_image_tensor,
pad_mask,
pad_video,
perspective,
perspective_bounding_box,
perspective_image_pil,
perspective_image_tensor,
perspective_mask,
perspective_video,
resize,
resize_bounding_box,
resize_image_pil,
resize_image_tensor,
resize_mask,
resize_video,
resized_crop,
resized_crop_bounding_box,
resized_crop_image_pil,
resized_crop_image_tensor,
resized_crop_mask,
resized_crop_video,
rotate,
rotate_bounding_box,
rotate_image_pil,
rotate_image_tensor,
rotate_mask,
rotate_video,
ten_crop,
ten_crop_image_pil,
ten_crop_image_tensor,
......@@ -113,9 +135,18 @@ from ._geometry import (
vertical_flip_image_pil,
vertical_flip_image_tensor,
vertical_flip_mask,
vertical_flip_video,
vflip,
)
from ._misc import gaussian_blur, gaussian_blur_image_pil, gaussian_blur_image_tensor, normalize, normalize_image_tensor
from ._misc import (
gaussian_blur,
gaussian_blur_image_pil,
gaussian_blur_image_tensor,
gaussian_blur_video,
normalize,
normalize_image_tensor,
normalize_video,
)
from ._type_conversion import (
convert_image_dtype,
decode_image_with_pil,
......
......@@ -17,19 +17,25 @@ def erase_image_pil(
return to_pil_image(output, mode=image.mode)
def erase_video(
video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
return erase_image_tensor(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
def erase(
inpt: features.ImageTypeJIT,
inpt: features.ImageOrVideoTypeJIT,
i: int,
j: int,
h: int,
w: int,
v: torch.Tensor,
inplace: bool = False,
) -> features.ImageTypeJIT:
) -> features.ImageOrVideoTypeJIT:
if isinstance(inpt, torch.Tensor):
output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
output = features.Image.new_like(inpt, output)
if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
output = type(inpt).new_like(inpt, output) # type: ignore[arg-type]
return output
else: # isinstance(inpt, PIL.Image.Image):
return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
......@@ -2,10 +2,16 @@ import torch
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
from ._meta import get_dimensions_image_tensor
adjust_brightness_image_tensor = _FT.adjust_brightness
adjust_brightness_image_pil = _FP.adjust_brightness
def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor:
return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor)
def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
......@@ -19,6 +25,10 @@ adjust_saturation_image_tensor = _FT.adjust_saturation
adjust_saturation_image_pil = _FP.adjust_saturation
def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor:
return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor)
def adjust_saturation(inpt: features.InputTypeJIT, saturation_factor: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
......@@ -32,6 +42,10 @@ adjust_contrast_image_tensor = _FT.adjust_contrast
adjust_contrast_image_pil = _FP.adjust_contrast
def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor:
return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor)
def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
......@@ -41,10 +55,40 @@ def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> feat
return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
adjust_sharpness_image_tensor = _FT.adjust_sharpness
def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
num_channels, height, width = get_dimensions_image_tensor(image)
if num_channels not in (1, 3):
raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}")
if sharpness_factor < 0:
raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.")
if image.numel() == 0 or height <= 2 or width <= 2:
return image
shape = image.shape
if image.ndim > 4:
image = image.view(-1, num_channels, height, width)
needs_unsquash = True
else:
needs_unsquash = False
output = _FT._blend(image, _FT._blurred_degenerate_image(image), sharpness_factor)
if needs_unsquash:
output = output.view(shape)
return output
adjust_sharpness_image_pil = _FP.adjust_sharpness
def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor)
def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
......@@ -58,6 +102,10 @@ adjust_hue_image_tensor = _FT.adjust_hue
adjust_hue_image_pil = _FP.adjust_hue
def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
return adjust_hue_image_tensor(video, hue_factor=hue_factor)
def adjust_hue(inpt: features.InputTypeJIT, hue_factor: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
......@@ -71,6 +119,10 @@ adjust_gamma_image_tensor = _FT.adjust_gamma
adjust_gamma_image_pil = _FP.adjust_gamma
def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain)
def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
......@@ -84,6 +136,10 @@ posterize_image_tensor = _FT.posterize
posterize_image_pil = _FP.posterize
def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
return posterize_image_tensor(video, bits=bits)
def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return posterize_image_tensor(inpt, bits=bits)
......@@ -97,6 +153,10 @@ solarize_image_tensor = _FT.solarize
solarize_image_pil = _FP.solarize
def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
return solarize_image_tensor(video, threshold=threshold)
def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return solarize_image_tensor(inpt, threshold=threshold)
......@@ -110,6 +170,10 @@ autocontrast_image_tensor = _FT.autocontrast
autocontrast_image_pil = _FP.autocontrast
def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
return autocontrast_image_tensor(video)
def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return autocontrast_image_tensor(inpt)
......@@ -119,10 +183,35 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
return autocontrast_image_pil(inpt)
equalize_image_tensor = _FT.equalize
def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
if image.dtype != torch.uint8:
raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}")
num_channels, height, width = get_dimensions_image_tensor(image)
if num_channels not in (1, 3):
raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}")
if image.numel() == 0:
return image
elif image.ndim == 2:
return _FT._scale_channel(image)
else:
return torch.stack(
[
# TODO: when merging transforms v1 and v2, we can inline this function call
_FT._equalize_single_image(single_image)
for single_image in image.view(-1, num_channels, height, width)
]
).view(image.shape)
equalize_image_pil = _FP.equalize
def equalize_video(video: torch.Tensor) -> torch.Tensor:
return equalize_image_tensor(video)
def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return equalize_image_tensor(inpt)
......@@ -136,6 +225,10 @@ invert_image_tensor = _FT.invert
invert_image_pil = _FP.invert
def invert_video(video: torch.Tensor) -> torch.Tensor:
return invert_image_tensor(video)
def invert(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return invert_image_tensor(inpt)
......
......@@ -47,6 +47,10 @@ def horizontal_flip_bounding_box(
).view(shape)
def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
return horizontal_flip_image_tensor(video)
def horizontal_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return horizontal_flip_image_tensor(inpt)
......@@ -80,6 +84,10 @@ def vertical_flip_bounding_box(
).view(shape)
def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
return vertical_flip_image_tensor(video)
def vertical_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return vertical_flip_image_tensor(inpt)
......@@ -185,6 +193,16 @@ def resize_bounding_box(
)
def resize_video(
video: torch.Tensor,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: bool = False,
) -> torch.Tensor:
return resize_image_tensor(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
def resize(
inpt: features.InputTypeJIT,
size: List[int],
......@@ -441,6 +459,28 @@ def affine_mask(
return output
def affine_video(
video: torch.Tensor,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: features.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
return affine_image_tensor(
video,
angle=angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
def _convert_fill_arg(fill: features.FillType) -> features.FillTypeJIT:
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# So, we can't reassign fill to 0
......@@ -614,6 +654,17 @@ def rotate_mask(
return output
def rotate_video(
video: torch.Tensor,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: features.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
return rotate_image_tensor(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
def rotate(
inpt: features.InputTypeJIT,
angle: float,
......@@ -751,6 +802,15 @@ def pad_bounding_box(
return bounding_box, (height, width)
def pad_video(
video: torch.Tensor,
padding: Union[int, List[int]],
fill: features.FillTypeJIT = None,
padding_mode: str = "constant",
) -> torch.Tensor:
return pad_image_tensor(video, padding, fill=fill, padding_mode=padding_mode)
def pad(
inpt: features.InputTypeJIT,
padding: Union[int, List[int]],
......@@ -798,6 +858,10 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int)
return crop_image_tensor(mask, top, left, height, width)
def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
return crop_image_tensor(video, top, left, height, width)
def crop(inpt: features.InputTypeJIT, top: int, left: int, height: int, width: int) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return crop_image_tensor(inpt, top, left, height, width)
......@@ -932,6 +996,33 @@ def perspective_mask(
return output
def perspective_video(
video: torch.Tensor,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: features.FillTypeJIT = None,
) -> torch.Tensor:
# TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when
# https://github.com/pytorch/vision/issues/6670 is resolved.
if video.numel() == 0:
return video
shape = video.shape
if video.ndim > 4:
video = video.view((-1,) + shape[-3:])
needs_unsquash = True
else:
needs_unsquash = False
output = perspective_image_tensor(video, perspective_coeffs, interpolation=interpolation, fill=fill)
if needs_unsquash:
output = output.view(shape)
return output
def perspective(
inpt: features.InputTypeJIT,
perspective_coeffs: List[float],
......@@ -1026,6 +1117,33 @@ def elastic_mask(
return output
def elastic_video(
video: torch.Tensor,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: features.FillTypeJIT = None,
) -> torch.Tensor:
# TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when
# https://github.com/pytorch/vision/issues/6670 is resolved.
if video.numel() == 0:
return video
shape = video.shape
if video.ndim > 4:
video = video.view((-1,) + shape[-3:])
needs_unsquash = True
else:
needs_unsquash = False
output = elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)
if needs_unsquash:
output = output.view(shape)
return output
def elastic(
inpt: features.InputTypeJIT,
displacement: torch.Tensor,
......@@ -1128,6 +1246,10 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor
return output
def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:
return center_crop_image_tensor(video, output_size)
def center_crop(inpt: features.InputTypeJIT, output_size: List[int]) -> features.InputTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
return center_crop_image_tensor(inpt, output_size)
......@@ -1190,6 +1312,21 @@ def resized_crop_mask(
return resize_mask(mask, size)
def resized_crop_video(
video: torch.Tensor,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: bool = False,
) -> torch.Tensor:
return resized_crop_image_tensor(
video, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
)
def resized_crop(
inpt: features.InputTypeJIT,
top: int,
......
......@@ -11,10 +11,12 @@ get_dimensions_image_pil = _FP.get_dimensions
# TODO: Should this be prefixed with `_` similar to other methods that don't get exposed by init?
def get_chw(image: features.ImageTypeJIT) -> Tuple[int, int, int]:
if isinstance(image, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(image, features.Image)):
def get_chw(image: features.ImageOrVideoTypeJIT) -> Tuple[int, int, int]:
if isinstance(image, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
):
channels, height, width = get_dimensions_image_tensor(image)
elif isinstance(image, features.Image):
elif isinstance(image, (features.Image, features.Video)):
channels = image.num_channels
height, width = image.image_size
else: # isinstance(image, PIL.Image.Image)
......@@ -29,11 +31,11 @@ def get_chw(image: features.ImageTypeJIT) -> Tuple[int, int, int]:
# detailed above.
def get_dimensions(image: features.ImageTypeJIT) -> List[int]:
def get_dimensions(image: features.ImageOrVideoTypeJIT) -> List[int]:
return list(get_chw(image))
def get_num_channels(image: features.ImageTypeJIT) -> int:
def get_num_channels(image: features.ImageOrVideoTypeJIT) -> int:
num_channels, *_ = get_chw(image)
return num_channels
......@@ -43,7 +45,7 @@ def get_num_channels(image: features.ImageTypeJIT) -> int:
get_image_num_channels = get_num_channels
def get_spatial_size(image: features.ImageTypeJIT) -> List[int]:
def get_spatial_size(image: features.ImageOrVideoTypeJIT) -> List[int]:
_, *size = get_chw(image)
return size
......@@ -207,13 +209,23 @@ def convert_color_space_image_pil(
return image.convert(new_mode)
def convert_color_space_video(
video: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace, copy: bool = True
) -> torch.Tensor:
return convert_color_space_image_tensor(
video, old_color_space=old_color_space, new_color_space=new_color_space, copy=copy
)
def convert_color_space(
inpt: features.ImageTypeJIT,
inpt: features.ImageOrVideoTypeJIT,
color_space: ColorSpace,
old_color_space: Optional[ColorSpace] = None,
copy: bool = True,
) -> features.ImageTypeJIT:
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image)):
) -> features.ImageOrVideoTypeJIT:
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
):
if old_color_space is None:
raise RuntimeError(
"In order to convert the color space of simple tensor images, "
......@@ -222,7 +234,7 @@ def convert_color_space(
return convert_color_space_image_tensor(
inpt, old_color_space=old_color_space, new_color_space=color_space, copy=copy
)
elif isinstance(inpt, features.Image):
elif isinstance(inpt, (features.Image, features.Video)):
return inpt.to_color_space(color_space, copy=copy)
else:
return cast(features.ImageTypeJIT, convert_color_space_image_pil(inpt, color_space, copy=copy))
return cast(features.ImageOrVideoTypeJIT, convert_color_space_image_pil(inpt, color_space, copy=copy))
......@@ -9,18 +9,22 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image
normalize_image_tensor = _FT.normalize
def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
return normalize_image_tensor(video, mean, std, inplace=inplace)
def normalize(
inpt: features.TensorImageTypeJIT, mean: List[float], std: List[float], inplace: bool = False
inpt: features.TensorImageOrVideoTypeJIT, mean: List[float], std: List[float], inplace: bool = False
) -> torch.Tensor:
if torch.jit.is_scripting():
correct_type = isinstance(inpt, torch.Tensor)
else:
correct_type = features.is_simple_tensor(inpt) or isinstance(inpt, features.Image)
correct_type = features.is_simple_tensor(inpt) or isinstance(inpt, (features.Image, features.Video))
inpt = inpt.as_subclass(torch.Tensor)
if not correct_type:
raise TypeError(f"img should be Tensor Image. Got {type(inpt)}")
# Image instance after normalization is not Image anymore due to unknown data range
# Image or Video type should not be retained after normalization due to unknown data range
# Thus we return Tensor for input Image
return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
......@@ -64,6 +68,30 @@ def gaussian_blur_image_pil(
return to_pil_image(output, mode=image.mode)
def gaussian_blur_video(
video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor:
# TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when
# https://github.com/pytorch/vision/issues/6670 is resolved.
if video.numel() == 0:
return video
shape = video.shape
if video.ndim > 4:
video = video.view((-1,) + shape[-3:])
needs_unsquash = True
else:
needs_unsquash = False
output = gaussian_blur_image_tensor(video, kernel_size, sigma)
if needs_unsquash:
output = output.view(shape)
return output
def gaussian_blur(
inpt: features.InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> features.InputTypeJIT:
......