"examples/mxnet/_deprecated/sampling/dis_sampling/sampler.py" did not exist on "688a9228a820c419d9548ea2b44a6e4fe0a2cc1e"
Unverified commit 3118fb52 authored by Philip Meier, committed by GitHub

add Video feature and kernels (#6667)

* add video feature

* add video kernels

* add video testing utils

* add one kernel info

* fix kernel names in Video feature

* use only uint8 for video testing

* require at least 4 dims for Video feature

* add TODO for image_size -> spatial_size

* image -> video in feature constructor

* introduce new combined images and video type

* add video to transform utils

* fix transforms test

* fix auto augment

* cleanup

* address review comments

* add remaining video kernel infos

* add batch dimension squashing to some kernels

* fix tests and kernel infos

* add xfails for arbitrary batch sizes on some kernels

* fix test setup

* fix equalize_image_tensor for multi batch dims

* fix adjust_sharpness_image_tensor for multi batch dims

* address review comments
parent 7eb5d7fc
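
The diff below introduces a features.Video type and matching *_video kernels. As a hedged orientation sketch (not part of the diff; the tensor values are arbitrary examples), the shape convention used throughout is: a video is a tensor with at least 4 dimensions, laid out as (..., T, C, H, W):

import torch

video = torch.rand(2, 8, 3, 32, 32)  # hypothetical batch of 2 clips: 8 frames, RGB, 32x32
num_frames, num_channels = video.shape[-4], video.shape[-3]
spatial_size = tuple(video.shape[-2:])
assert (num_frames, num_channels, spatial_size) == (8, 3, (32, 32))
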
@@ -45,6 +45,8 @@ __all__ = [
    "make_segmentation_masks",
    "make_mask_loaders",
    "make_masks",
+    "make_video",
+    "make_videos",
]
@@ -210,17 +212,19 @@ DEFAULT_EXTRA_DIMS = (*VALID_EXTRA_DIMS, *DEGENERATE_BATCH_DIMS)
def from_loader(loader_fn):
    def wrapper(*args, **kwargs):
+        device = kwargs.pop("device", "cpu")
        loader = loader_fn(*args, **kwargs)
-        return loader.load(kwargs.get("device", "cpu"))
+        return loader.load(device)

    return wrapper


def from_loaders(loaders_fn):
    def wrapper(*args, **kwargs):
+        device = kwargs.pop("device", "cpu")
        loaders = loaders_fn(*args, **kwargs)
        for loader in loaders:
-            yield loader.load(kwargs.get("device", "cpu"))
+            yield loader.load(device)

    return wrapper
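
Why the wrappers switch from kwargs.get to kwargs.pop: with get, "device" stays in kwargs and is forwarded to the loader factory, which does not accept it. A minimal, self-contained sketch of the failure mode (hypothetical names, not from the diff):

class Loader:
    def load(self, device):
        return f"tensor on {device}"

def make_loader():  # accepts no `device` keyword, like the loader factories here
    return Loader()

def from_loader(loader_fn):
    def wrapper(*args, **kwargs):
        device = kwargs.pop("device", "cpu")  # consume before forwarding
        return loader_fn(*args, **kwargs).load(device)
    return wrapper

make = from_loader(make_loader)
print(make(device="cuda"))  # with kwargs.get(...), make_loader() would receive
                            # device="cuda" and raise a TypeError
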
@@ -246,6 +250,21 @@ class ImageLoader(TensorLoader):
        self.num_channels = self.shape[-3]

+NUM_CHANNELS_MAP = {
+    features.ColorSpace.GRAY: 1,
+    features.ColorSpace.GRAY_ALPHA: 2,
+    features.ColorSpace.RGB: 3,
+    features.ColorSpace.RGB_ALPHA: 4,
+}
+
+
+def get_num_channels(color_space):
+    num_channels = NUM_CHANNELS_MAP.get(color_space)
+    if not num_channels:
+        raise pytest.UsageError(f"Can't determine the number of channels for color space {color_space}")
+    return num_channels
+
+
def make_image_loader(
    size="random",
    *,
@@ -255,16 +274,7 @@ def make_image_loader(
    constant_alpha=True,
):
    size = _parse_image_size(size)
-    try:
-        num_channels = {
-            features.ColorSpace.GRAY: 1,
-            features.ColorSpace.GRAY_ALPHA: 2,
-            features.ColorSpace.RGB: 3,
-            features.ColorSpace.RGB_ALPHA: 4,
-        }[color_space]
-    except KeyError as error:
-        raise pytest.UsageError(f"Can't determine the number of channels for color space {color_space}") from error
+    num_channels = get_num_channels(color_space)

    def fn(shape, dtype, device):
        max_value = get_max_value(dtype)
@@ -531,3 +541,50 @@ def make_mask_loaders(
make_masks = from_loaders(make_mask_loaders)

+
+class VideoLoader(ImageLoader):
+    pass
+
+
+def make_video_loader(
+    size="random",
+    *,
+    color_space=features.ColorSpace.RGB,
+    num_frames="random",
+    extra_dims=(),
+    dtype=torch.uint8,
+):
+    size = _parse_image_size(size)
+    num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames
+
+    def fn(shape, dtype, device):
+        video = make_image(size=shape[-2:], color_space=color_space, extra_dims=shape[:-3], dtype=dtype, device=device)
+        return features.Video(video, color_space=color_space)
+
+    return VideoLoader(
+        fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype, color_space=color_space
+    )
+
+
+make_video = from_loader(make_video_loader)
+
+
+def make_video_loaders(
+    *,
+    sizes=DEFAULT_IMAGE_SIZES,
+    color_spaces=(
+        features.ColorSpace.GRAY,
+        features.ColorSpace.RGB,
+    ),
+    num_frames=(1, 0, "random"),
+    extra_dims=DEFAULT_EXTRA_DIMS,
+    dtypes=(torch.uint8,),
+):
+    for params in combinations_grid(
+        size=sizes, color_space=color_spaces, num_frames=num_frames, extra_dims=extra_dims, dtype=dtypes
+    ):
+        yield make_video_loader(**params)
+
+
+make_videos = from_loaders(make_video_loaders)
@@ -127,6 +127,23 @@ xfail_dispatch_pil_if_fill_sequence_needs_broadcast = TestMark(
)

+
+def xfail_all_tests(*, reason, condition):
+    return [
+        TestMark(("TestDispatchers", test_name), pytest.mark.xfail(reason=reason), condition=condition)
+        for test_name in [
+            "test_scripted_smoke",
+            "test_dispatch_simple_tensor",
+            "test_dispatch_feature",
+        ]
+    ]
+
+
+xfails_degenerate_or_multi_batch_dims = xfail_all_tests(
+    reason="See https://github.com/pytorch/vision/issues/6670 for details.",
+    condition=lambda args_kwargs: len(args_kwargs.args[0].shape) > 4 or not all(args_kwargs.args[0].shape[:-3]),
+)
DISPATCHER_INFOS = [
    DispatcherInfo(
        F.horizontal_flip,
@@ -243,6 +260,7 @@ DISPATCHER_INFOS = [
        pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
        test_marks=[
            xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
+            *xfails_degenerate_or_multi_batch_dims,
        ],
    ),
    DispatcherInfo(
@@ -253,6 +271,7 @@ DISPATCHER_INFOS = [
            features.Mask: F.elastic_mask,
        },
        pil_kernel_info=PILKernelInfo(F.elastic_image_pil),
+        test_marks=xfails_degenerate_or_multi_batch_dims,
    ),
    DispatcherInfo(
        F.center_crop,
@@ -275,6 +294,7 @@ DISPATCHER_INFOS = [
        test_marks=[
            xfail_jit_python_scalar_arg("kernel_size"),
            xfail_jit_python_scalar_arg("sigma"),
+            *xfails_degenerate_or_multi_batch_dims,
        ],
    ),
    DispatcherInfo(
@@ -20,6 +20,7 @@ from prototype_common_utils import (
    make_image_loader,
    make_image_loaders,
    make_mask_loaders,
+    make_video_loaders,
    VALID_EXTRA_DIMS,
)
from torchvision.prototype import features
@@ -142,6 +143,25 @@ def xfail_jit_list_of_ints(name, *, reason=None):
)

+
+def xfail_all_tests(*, reason, condition):
+    return [
+        TestMark(("TestKernels", test_name), pytest.mark.xfail(reason=reason), condition=condition)
+        for test_name in [
+            "test_scripted_vs_eager",
+            "test_batched_vs_single",
+            "test_no_inplace",
+            "test_cuda_vs_cpu",
+            "test_dtype_and_device_consistency",
+        ]
+    ]
+
+
+xfails_image_degenerate_or_multi_batch_dims = xfail_all_tests(
+    reason="See https://github.com/pytorch/vision/issues/6670 for details.",
+    condition=lambda args_kwargs: len(args_kwargs.args[0].shape) > 4 or not all(args_kwargs.args[0].shape[:-3]),
+)
KERNEL_INFOS = []
@@ -169,6 +189,11 @@ def sample_inputs_horizontal_flip_mask():
        yield ArgsKwargs(image_loader)

+
+def sample_inputs_horizontal_flip_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader)
+
+
KERNEL_INFOS.extend(
    [
        KernelInfo(
@@ -187,6 +212,10 @@ KERNEL_INFOS.extend(
            F.horizontal_flip_mask,
            sample_inputs_fn=sample_inputs_horizontal_flip_mask,
        ),
+        KernelInfo(
+            F.horizontal_flip_video,
+            sample_inputs_fn=sample_inputs_horizontal_flip_video,
+        ),
    ]
)
@@ -287,6 +316,11 @@ def reference_inputs_resize_mask():
        yield ArgsKwargs(mask_loader, size=size)

+
+def sample_inputs_resize_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader, size=[min(video_loader.shape[-2:]) + 1])
+
+
KERNEL_INFOS.extend(
    [
        KernelInfo(
@@ -316,6 +350,10 @@ KERNEL_INFOS.extend(
                xfail_jit_integer_size(),
            ],
        ),
+        KernelInfo(
+            F.resize_video,
+            sample_inputs_fn=sample_inputs_resize_video,
+        ),
    ]
)
@@ -485,7 +523,7 @@ def reference_inputs_affine_bounding_box():
    )
-def sample_inputs_affine_image_mask():
+def sample_inputs_affine_mask():
    for mask_loader in make_mask_loaders(sizes=["random"], num_categories=["random"], num_objects=["random"]):
        yield ArgsKwargs(mask_loader, **_full_affine_params())
@@ -502,6 +540,11 @@ def reference_inputs_resize_mask():
        yield ArgsKwargs(mask_loader, **affine_kwargs)

+
+def sample_inputs_affine_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader, **_full_affine_params())
KERNEL_INFOS.extend(
    [
        KernelInfo(
@@ -529,7 +572,7 @@ KERNEL_INFOS.extend(
        ),
        KernelInfo(
            F.affine_mask,
-            sample_inputs_fn=sample_inputs_affine_image_mask,
+            sample_inputs_fn=sample_inputs_affine_mask,
            reference_fn=reference_affine_mask,
            reference_inputs_fn=reference_inputs_resize_mask,
            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
@@ -537,6 +580,10 @@ KERNEL_INFOS.extend(
                xfail_jit_python_scalar_arg("shear"),
            ],
        ),
+        KernelInfo(
+            F.affine_video,
+            sample_inputs_fn=sample_inputs_affine_video,
+        ),
    ]
)
@@ -608,14 +655,28 @@ def reference_inputs_convert_color_space_image_tensor():
        yield args_kwargs
-KERNEL_INFOS.append(
-    KernelInfo(
-        F.convert_color_space_image_tensor,
-        sample_inputs_fn=sample_inputs_convert_color_space_image_tensor,
-        reference_fn=reference_convert_color_space_image_tensor,
-        reference_inputs_fn=reference_inputs_convert_color_space_image_tensor,
-        closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-    ),
+def sample_inputs_convert_color_space_video():
+    color_spaces = [features.ColorSpace.GRAY, features.ColorSpace.RGB]
+
+    for old_color_space, new_color_space in cycle_over(color_spaces):
+        for video_loader in make_video_loaders(sizes=["random"], color_spaces=[old_color_space], num_frames=["random"]):
+            yield ArgsKwargs(video_loader, old_color_space=old_color_space, new_color_space=new_color_space)
+
+
+KERNEL_INFOS.extend(
+    [
+        KernelInfo(
+            F.convert_color_space_image_tensor,
+            sample_inputs_fn=sample_inputs_convert_color_space_image_tensor,
+            reference_fn=reference_convert_color_space_image_tensor,
+            reference_inputs_fn=reference_inputs_convert_color_space_image_tensor,
+            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
+        ),
+        KernelInfo(
+            F.convert_color_space_video,
+            sample_inputs_fn=sample_inputs_convert_color_space_video,
+        ),
+    ]
)
@@ -643,6 +704,11 @@ def sample_inputs_vertical_flip_mask():
        yield ArgsKwargs(image_loader)
+def sample_inputs_vertical_flip_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader)
+
+
KERNEL_INFOS.extend(
    [
        KernelInfo(
@@ -661,6 +727,10 @@ KERNEL_INFOS.extend(
            F.vertical_flip_mask,
            sample_inputs_fn=sample_inputs_vertical_flip_mask,
        ),
+        KernelInfo(
+            F.vertical_flip_video,
+            sample_inputs_fn=sample_inputs_vertical_flip_video,
+        ),
    ]
)
@@ -724,6 +794,11 @@ def reference_inputs_rotate_mask():
        yield ArgsKwargs(mask_loader, angle=angle)
+def sample_inputs_rotate_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader, angle=15.0)
+
+
KERNEL_INFOS.extend(
    [
        KernelInfo(
@@ -749,6 +824,10 @@ KERNEL_INFOS.extend(
            reference_inputs_fn=reference_inputs_rotate_mask,
            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
        ),
+        KernelInfo(
+            F.rotate_video,
+            sample_inputs_fn=sample_inputs_rotate_video,
+        ),
    ]
)
@@ -791,6 +870,11 @@ def reference_inputs_crop_mask():
        yield ArgsKwargs(mask_loader, **params)
+def sample_inputs_crop_video():
+    for video_loader in make_video_loaders(sizes=[(16, 17)], num_frames=["random"]):
+        yield ArgsKwargs(video_loader, top=4, left=3, height=7, width=8)
+
+
KERNEL_INFOS.extend(
    [
        KernelInfo(
@@ -812,6 +896,10 @@ KERNEL_INFOS.extend(
            reference_inputs_fn=reference_inputs_crop_mask,
            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
        ),
+        KernelInfo(
+            F.crop_video,
+            sample_inputs_fn=sample_inputs_crop_video,
+        ),
    ]
)
@@ -872,6 +960,11 @@ def reference_inputs_resized_crop_mask():
        yield ArgsKwargs(mask_loader, **params)
+def sample_inputs_resized_crop_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader, **_RESIZED_CROP_PARAMS[0])
+
+
KERNEL_INFOS.extend(
    [
        KernelInfo(
@@ -892,6 +985,10 @@ KERNEL_INFOS.extend(
            reference_inputs_fn=reference_inputs_resized_crop_mask,
            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
        ),
+        KernelInfo(
+            F.resized_crop_video,
+            sample_inputs_fn=sample_inputs_resized_crop_video,
+        ),
    ]
)
@@ -965,6 +1062,11 @@ def reference_inputs_pad_mask():
        yield ArgsKwargs(image_loader, fill=fill, **params)
+def sample_inputs_pad_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader, padding=[1])
+
+
KERNEL_INFOS.extend(
    [
        KernelInfo(
@@ -996,6 +1098,10 @@ KERNEL_INFOS.extend(
            reference_inputs_fn=reference_inputs_pad_mask,
            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
        ),
+        KernelInfo(
+            F.pad_video,
+            sample_inputs_fn=sample_inputs_pad_video,
+        ),
    ]
)
@@ -1006,11 +1112,7 @@ _PERSPECTIVE_COEFFS = [
def sample_inputs_perspective_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"],
-        # FIXME: kernel should support arbitrary batch sizes
-        extra_dims=[(), (4,)],
-    ):
+    for image_loader in make_image_loaders(sizes=["random"]):
        for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]:
            yield ArgsKwargs(image_loader, fill=fill, perspective_coeffs=_PERSPECTIVE_COEFFS[0])
@@ -1030,11 +1132,7 @@ def sample_inputs_perspective_bounding_box():
def sample_inputs_perspective_mask():
-    for mask_loader in make_mask_loaders(
-        sizes=["random"],
-        # FIXME: kernel should support arbitrary batch sizes
-        extra_dims=[(), (4,)],
-    ):
+    for mask_loader in make_mask_loaders(sizes=["random"]):
        yield ArgsKwargs(mask_loader, perspective_coeffs=_PERSPECTIVE_COEFFS[0])
@@ -1045,6 +1143,11 @@ def reference_inputs_perspective_mask():
        yield ArgsKwargs(mask_loader, perspective_coeffs=perspective_coeffs)

+
+def sample_inputs_perspective_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader, perspective_coeffs=_PERSPECTIVE_COEFFS[0])
KERNEL_INFOS.extend(
    [
        KernelInfo(
@@ -1053,6 +1156,7 @@ KERNEL_INFOS.extend(
            reference_fn=pil_reference_wrapper(F.perspective_image_pil),
            reference_inputs_fn=reference_inputs_perspective_image_tensor,
            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
+            test_marks=xfails_image_degenerate_or_multi_batch_dims,
        ),
        KernelInfo(
            F.perspective_bounding_box,
@@ -1064,6 +1168,11 @@ KERNEL_INFOS.extend(
            reference_fn=pil_reference_wrapper(F.perspective_image_pil),
            reference_inputs_fn=reference_inputs_perspective_mask,
            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
+            test_marks=xfails_image_degenerate_or_multi_batch_dims,
+        ),
+        KernelInfo(
+            F.perspective_video,
+            sample_inputs_fn=sample_inputs_perspective_video,
        ),
    ]
)
@@ -1074,11 +1183,7 @@ def _get_elastic_displacement(image_size):
def sample_inputs_elastic_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"],
-        # FIXME: kernel should support arbitrary batch sizes
-        extra_dims=[(), (4,)],
-    ):
+    for image_loader in make_image_loaders(sizes=["random"]):
        displacement = _get_elastic_displacement(image_loader.image_size)
        for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]:
            yield ArgsKwargs(image_loader, displacement=displacement, fill=fill)
@@ -1109,11 +1214,7 @@ def sample_inputs_elastic_bounding_box():
def sample_inputs_elastic_mask():
-    for mask_loader in make_mask_loaders(
-        sizes=["random"],
-        # FIXME: kernel should support arbitrary batch sizes
-        extra_dims=[(), (4,)],
-    ):
+    for mask_loader in make_mask_loaders(sizes=["random"]):
        displacement = _get_elastic_displacement(mask_loader.shape[-2:])
        yield ArgsKwargs(mask_loader, displacement=displacement)
@@ -1124,6 +1225,12 @@ def reference_inputs_elastic_mask():
        yield ArgsKwargs(mask_loader, displacement=displacement)

+
+def sample_inputs_elastic_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        displacement = _get_elastic_displacement(video_loader.shape[-2:])
+        yield ArgsKwargs(video_loader, displacement=displacement)
KERNEL_INFOS.extend(
    [
        KernelInfo(
@@ -1132,6 +1239,7 @@ KERNEL_INFOS.extend(
            reference_fn=pil_reference_wrapper(F.elastic_image_pil),
            reference_inputs_fn=reference_inputs_elastic_image_tensor,
            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
+            test_marks=xfails_image_degenerate_or_multi_batch_dims,
        ),
        KernelInfo(
            F.elastic_bounding_box,
@@ -1143,6 +1251,11 @@ KERNEL_INFOS.extend(
            reference_fn=pil_reference_wrapper(F.elastic_image_pil),
            reference_inputs_fn=reference_inputs_elastic_mask,
            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
+            test_marks=xfails_image_degenerate_or_multi_batch_dims,
+        ),
+        KernelInfo(
+            F.elastic_video,
+            sample_inputs_fn=sample_inputs_elastic_video,
        ),
    ]
)
@@ -1195,6 +1308,12 @@ def reference_inputs_center_crop_mask():
        yield ArgsKwargs(mask_loader, output_size=output_size)

+
+def sample_inputs_center_crop_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        height, width = video_loader.shape[-2:]
+        yield ArgsKwargs(video_loader, output_size=(height // 2, width // 2))
+
+
KERNEL_INFOS.extend(
    [
        KernelInfo(
@@ -1224,17 +1343,17 @@ KERNEL_INFOS.extend(
                xfail_jit_integer_size("output_size"),
            ],
        ),
+        KernelInfo(
+            F.center_crop_video,
+            sample_inputs_fn=sample_inputs_center_crop_video,
+        ),
    ]
)


def sample_inputs_gaussian_blur_image_tensor():
    make_gaussian_blur_image_loaders = functools.partial(
-        make_image_loaders,
-        sizes=["random"],
-        color_spaces=[features.ColorSpace.RGB],
-        # FIXME: kernel should support arbitrary batch sizes
-        extra_dims=[(), (4,)],
+        make_image_loaders, sizes=["random"], color_spaces=[features.ColorSpace.RGB]
    )

    for image_loader, kernel_size in itertools.product(make_gaussian_blur_image_loaders(), [5, (3, 3), [3, 3]]):
@@ -1246,26 +1365,34 @@ def sample_inputs_gaussian_blur_image_tensor():
        yield ArgsKwargs(image_loader, kernel_size=5, sigma=sigma)
-KERNEL_INFOS.append(
-    KernelInfo(
-        F.gaussian_blur_image_tensor,
-        sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor,
-        closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-        test_marks=[
-            xfail_jit_python_scalar_arg("kernel_size"),
-            xfail_jit_python_scalar_arg("sigma"),
-        ],
-    )
+def sample_inputs_gaussian_blur_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader, kernel_size=[3, 3])
+
+
+KERNEL_INFOS.extend(
+    [
+        KernelInfo(
+            F.gaussian_blur_image_tensor,
+            sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor,
+            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
+            test_marks=[
+                xfail_jit_python_scalar_arg("kernel_size"),
+                xfail_jit_python_scalar_arg("sigma"),
+                *xfails_image_degenerate_or_multi_batch_dims,
+            ],
+        ),
+        KernelInfo(
+            F.gaussian_blur_video,
+            sample_inputs_fn=sample_inputs_gaussian_blur_video,
+        ),
+    ]
)
def sample_inputs_equalize_image_tensor():
    for image_loader in make_image_loaders(
-        sizes=["random"],
-        # FIXME: kernel should support arbitrary batch sizes
-        extra_dims=[(), (4,)],
-        color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB),
-        dtypes=[torch.uint8],
+        sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), dtypes=[torch.uint8]
    ):
        yield ArgsKwargs(image_loader)
@@ -1277,15 +1404,26 @@ def reference_inputs_equalize_image_tensor():
        yield ArgsKwargs(image_loader)
-KERNEL_INFOS.append(
-    KernelInfo(
-        F.equalize_image_tensor,
-        kernel_name="equalize_image_tensor",
-        sample_inputs_fn=sample_inputs_equalize_image_tensor,
-        reference_fn=pil_reference_wrapper(F.equalize_image_pil),
-        reference_inputs_fn=reference_inputs_equalize_image_tensor,
-        closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-    )
+def sample_inputs_equalize_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader)
+
+
+KERNEL_INFOS.extend(
+    [
+        KernelInfo(
+            F.equalize_image_tensor,
+            kernel_name="equalize_image_tensor",
+            sample_inputs_fn=sample_inputs_equalize_image_tensor,
+            reference_fn=pil_reference_wrapper(F.equalize_image_pil),
+            reference_inputs_fn=reference_inputs_equalize_image_tensor,
+            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
+        ),
+        KernelInfo(
+            F.equalize_video,
+            sample_inputs_fn=sample_inputs_equalize_video,
+        ),
+    ]
)
@@ -1303,15 +1441,26 @@ def reference_inputs_invert_image_tensor():
        yield ArgsKwargs(image_loader)
-KERNEL_INFOS.append(
-    KernelInfo(
-        F.invert_image_tensor,
-        kernel_name="invert_image_tensor",
-        sample_inputs_fn=sample_inputs_invert_image_tensor,
-        reference_fn=pil_reference_wrapper(F.invert_image_pil),
-        reference_inputs_fn=reference_inputs_invert_image_tensor,
-        closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-    )
+def sample_inputs_invert_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader)
+
+
+KERNEL_INFOS.extend(
+    [
+        KernelInfo(
+            F.invert_image_tensor,
+            kernel_name="invert_image_tensor",
+            sample_inputs_fn=sample_inputs_invert_image_tensor,
+            reference_fn=pil_reference_wrapper(F.invert_image_pil),
+            reference_inputs_fn=reference_inputs_invert_image_tensor,
+            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
+        ),
+        KernelInfo(
+            F.invert_video,
+            sample_inputs_fn=sample_inputs_invert_video,
+        ),
+    ]
)
@@ -1335,15 +1484,26 @@ def reference_inputs_posterize_image_tensor():
        yield ArgsKwargs(image_loader, bits=bits)
-KERNEL_INFOS.append(
-    KernelInfo(
-        F.posterize_image_tensor,
-        kernel_name="posterize_image_tensor",
-        sample_inputs_fn=sample_inputs_posterize_image_tensor,
-        reference_fn=pil_reference_wrapper(F.posterize_image_pil),
-        reference_inputs_fn=reference_inputs_posterize_image_tensor,
-        closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-    )
+def sample_inputs_posterize_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader, bits=_POSTERIZE_BITS[0])
+
+
+KERNEL_INFOS.extend(
+    [
+        KernelInfo(
+            F.posterize_image_tensor,
+            kernel_name="posterize_image_tensor",
+            sample_inputs_fn=sample_inputs_posterize_image_tensor,
+            reference_fn=pil_reference_wrapper(F.posterize_image_pil),
+            reference_inputs_fn=reference_inputs_posterize_image_tensor,
+            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
+        ),
+        KernelInfo(
+            F.posterize_video,
+            sample_inputs_fn=sample_inputs_posterize_video,
+        ),
+    ]
)
@@ -1368,15 +1528,26 @@ def reference_inputs_solarize_image_tensor():
        yield ArgsKwargs(image_loader, threshold=threshold)
-KERNEL_INFOS.append(
-    KernelInfo(
-        F.solarize_image_tensor,
-        kernel_name="solarize_image_tensor",
-        sample_inputs_fn=sample_inputs_solarize_image_tensor,
-        reference_fn=pil_reference_wrapper(F.solarize_image_pil),
-        reference_inputs_fn=reference_inputs_solarize_image_tensor,
-        closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-    )
+def sample_inputs_solarize_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader, threshold=next(_get_solarize_thresholds(video_loader.dtype)))
+
+
+KERNEL_INFOS.extend(
+    [
+        KernelInfo(
+            F.solarize_image_tensor,
+            kernel_name="solarize_image_tensor",
+            sample_inputs_fn=sample_inputs_solarize_image_tensor,
+            reference_fn=pil_reference_wrapper(F.solarize_image_pil),
+            reference_inputs_fn=reference_inputs_solarize_image_tensor,
+            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
+        ),
+        KernelInfo(
+            F.solarize_video,
+            sample_inputs_fn=sample_inputs_solarize_video,
+        ),
+    ]
)
@@ -1394,15 +1565,26 @@ def reference_inputs_autocontrast_image_tensor():
        yield ArgsKwargs(image_loader)
-KERNEL_INFOS.append(
-    KernelInfo(
-        F.autocontrast_image_tensor,
-        kernel_name="autocontrast_image_tensor",
-        sample_inputs_fn=sample_inputs_autocontrast_image_tensor,
-        reference_fn=pil_reference_wrapper(F.autocontrast_image_pil),
-        reference_inputs_fn=reference_inputs_autocontrast_image_tensor,
-        closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-    )
+def sample_inputs_autocontrast_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader)
+
+
+KERNEL_INFOS.extend(
+    [
+        KernelInfo(
+            F.autocontrast_image_tensor,
+            kernel_name="autocontrast_image_tensor",
+            sample_inputs_fn=sample_inputs_autocontrast_image_tensor,
+            reference_fn=pil_reference_wrapper(F.autocontrast_image_pil),
+            reference_inputs_fn=reference_inputs_autocontrast_image_tensor,
+            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
+        ),
+        KernelInfo(
+            F.autocontrast_video,
+            sample_inputs_fn=sample_inputs_autocontrast_video,
+        ),
+    ]
)
_ADJUST_SHARPNESS_FACTORS = [0.1, 0.5]
@@ -1412,8 +1594,6 @@ def sample_inputs_adjust_sharpness_image_tensor():
    for image_loader in make_image_loaders(
        sizes=["random", (2, 2)],
        color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB),
-        # FIXME: kernel should support arbitrary batch sizes
-        extra_dims=[(), (4,)],
    ):
        yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0])
@@ -1426,15 +1606,26 @@ def reference_inputs_adjust_sharpness_image_tensor():
        yield ArgsKwargs(image_loader, sharpness_factor=sharpness_factor)
-KERNEL_INFOS.append(
-    KernelInfo(
-        F.adjust_sharpness_image_tensor,
-        kernel_name="adjust_sharpness_image_tensor",
-        sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor,
-        reference_fn=pil_reference_wrapper(F.adjust_sharpness_image_pil),
-        reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor,
-        closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-    )
+def sample_inputs_adjust_sharpness_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0])
+
+
+KERNEL_INFOS.extend(
+    [
+        KernelInfo(
+            F.adjust_sharpness_image_tensor,
+            kernel_name="adjust_sharpness_image_tensor",
+            sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor,
+            reference_fn=pil_reference_wrapper(F.adjust_sharpness_image_pil),
+            reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor,
+            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
+        ),
+        KernelInfo(
+            F.adjust_sharpness_video,
+            sample_inputs_fn=sample_inputs_adjust_sharpness_video,
+        ),
+    ]
)
@@ -1446,12 +1637,26 @@ def sample_inputs_erase_image_tensor():
        yield ArgsKwargs(image_loader, i=1, j=2, h=h, w=w, v=v)
-KERNEL_INFOS.append(
-    KernelInfo(
-        F.erase_image_tensor,
-        kernel_name="erase_image_tensor",
-        sample_inputs_fn=sample_inputs_erase_image_tensor,
-    )
+def sample_inputs_erase_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        # FIXME: make the parameters more diverse
+        h, w = 6, 7
+        v = torch.rand(video_loader.num_channels, h, w)
+        yield ArgsKwargs(video_loader, i=1, j=2, h=h, w=w, v=v)
+
+
+KERNEL_INFOS.extend(
+    [
+        KernelInfo(
+            F.erase_image_tensor,
+            kernel_name="erase_image_tensor",
+            sample_inputs_fn=sample_inputs_erase_image_tensor,
+        ),
+        KernelInfo(
+            F.erase_video,
+            sample_inputs_fn=sample_inputs_erase_video,
+        ),
+    ]
)
_ADJUST_BRIGHTNESS_FACTORS = [0.1, 0.5]
@@ -1472,15 +1677,26 @@ def reference_inputs_adjust_brightness_image_tensor():
        yield ArgsKwargs(image_loader, brightness_factor=brightness_factor)

-KERNEL_INFOS.append(
-    KernelInfo(
-        F.adjust_brightness_image_tensor,
-        kernel_name="adjust_brightness_image_tensor",
-        sample_inputs_fn=sample_inputs_adjust_brightness_image_tensor,
-        reference_fn=pil_reference_wrapper(F.adjust_brightness_image_pil),
-        reference_inputs_fn=reference_inputs_adjust_brightness_image_tensor,
-        closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-    )
+def sample_inputs_adjust_brightness_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0])
+
+
+KERNEL_INFOS.extend(
+    [
+        KernelInfo(
+            F.adjust_brightness_image_tensor,
+            kernel_name="adjust_brightness_image_tensor",
+            sample_inputs_fn=sample_inputs_adjust_brightness_image_tensor,
+            reference_fn=pil_reference_wrapper(F.adjust_brightness_image_pil),
+            reference_inputs_fn=reference_inputs_adjust_brightness_image_tensor,
+            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
+        ),
+        KernelInfo(
+            F.adjust_brightness_video,
+            sample_inputs_fn=sample_inputs_adjust_brightness_video,
+        ),
+    ]
)
@@ -1502,15 +1718,26 @@ def reference_inputs_adjust_contrast_image_tensor():
        yield ArgsKwargs(image_loader, contrast_factor=contrast_factor)
-KERNEL_INFOS.append(
-    KernelInfo(
-        F.adjust_contrast_image_tensor,
-        kernel_name="adjust_contrast_image_tensor",
-        sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor,
-        reference_fn=pil_reference_wrapper(F.adjust_contrast_image_pil),
-        reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor,
-        closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-    )
+def sample_inputs_adjust_contrast_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0])
+
+
+KERNEL_INFOS.extend(
+    [
+        KernelInfo(
+            F.adjust_contrast_image_tensor,
+            kernel_name="adjust_contrast_image_tensor",
+            sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor,
+            reference_fn=pil_reference_wrapper(F.adjust_contrast_image_pil),
+            reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor,
+            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
+        ),
+        KernelInfo(
+            F.adjust_contrast_video,
+            sample_inputs_fn=sample_inputs_adjust_contrast_video,
+        ),
+    ]
)

_ADJUST_GAMMA_GAMMAS_GAINS = [
@@ -1535,15 +1762,27 @@ def reference_inputs_adjust_gamma_image_tensor():
        yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)
-KERNEL_INFOS.append(
-    KernelInfo(
-        F.adjust_gamma_image_tensor,
-        kernel_name="adjust_gamma_image_tensor",
-        sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor,
-        reference_fn=pil_reference_wrapper(F.adjust_gamma_image_pil),
-        reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor,
-        closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-    )
+def sample_inputs_adjust_gamma_video():
+    gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0]
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader, gamma=gamma, gain=gain)
+
+
+KERNEL_INFOS.extend(
+    [
+        KernelInfo(
+            F.adjust_gamma_image_tensor,
+            kernel_name="adjust_gamma_image_tensor",
+            sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor,
+            reference_fn=pil_reference_wrapper(F.adjust_gamma_image_pil),
+            reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor,
+            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
+        ),
+        KernelInfo(
+            F.adjust_gamma_video,
+            sample_inputs_fn=sample_inputs_adjust_gamma_video,
+        ),
+    ]
)
@@ -1565,15 +1804,26 @@ def reference_inputs_adjust_hue_image_tensor():
        yield ArgsKwargs(image_loader, hue_factor=hue_factor)
-KERNEL_INFOS.append(
-    KernelInfo(
-        F.adjust_hue_image_tensor,
-        kernel_name="adjust_hue_image_tensor",
-        sample_inputs_fn=sample_inputs_adjust_hue_image_tensor,
-        reference_fn=pil_reference_wrapper(F.adjust_hue_image_pil),
-        reference_inputs_fn=reference_inputs_adjust_hue_image_tensor,
-        closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-    )
+def sample_inputs_adjust_hue_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader, hue_factor=_ADJUST_HUE_FACTORS[0])
+
+
+KERNEL_INFOS.extend(
+    [
+        KernelInfo(
+            F.adjust_hue_image_tensor,
+            kernel_name="adjust_hue_image_tensor",
+            sample_inputs_fn=sample_inputs_adjust_hue_image_tensor,
+            reference_fn=pil_reference_wrapper(F.adjust_hue_image_pil),
+            reference_inputs_fn=reference_inputs_adjust_hue_image_tensor,
+            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
+        ),
+        KernelInfo(
+            F.adjust_hue_video,
+            sample_inputs_fn=sample_inputs_adjust_hue_video,
+        ),
+    ]
)

_ADJUST_SATURATION_FACTORS = [0.1, 0.5]
@@ -1594,15 +1844,26 @@ def reference_inputs_adjust_saturation_image_tensor():
        yield ArgsKwargs(image_loader, saturation_factor=saturation_factor)
-KERNEL_INFOS.append(
-    KernelInfo(
-        F.adjust_saturation_image_tensor,
-        kernel_name="adjust_saturation_image_tensor",
-        sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor,
-        reference_fn=pil_reference_wrapper(F.adjust_saturation_image_pil),
-        reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor,
-        closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-    )
+def sample_inputs_adjust_saturation_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0])
+
+
+KERNEL_INFOS.extend(
+    [
+        KernelInfo(
+            F.adjust_saturation_image_tensor,
+            kernel_name="adjust_saturation_image_tensor",
+            sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor,
+            reference_fn=pil_reference_wrapper(F.adjust_saturation_image_pil),
+            reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor,
+            closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
+        ),
+        KernelInfo(
+            F.adjust_saturation_video,
+            sample_inputs_fn=sample_inputs_adjust_saturation_video,
+        ),
+    ]
)
@@ -1702,10 +1963,24 @@ def sample_inputs_normalize_image_tensor():
        yield ArgsKwargs(image_loader, mean=mean, std=std)
-KERNEL_INFOS.append(
-    KernelInfo(
-        F.normalize_image_tensor,
-        kernel_name="normalize_image_tensor",
-        sample_inputs_fn=sample_inputs_normalize_image_tensor,
-    )
+def sample_inputs_normalize_video():
+    mean, std = _NORMALIZE_MEANS_STDS[0]
+    for video_loader in make_video_loaders(
+        sizes=["random"], color_spaces=[features.ColorSpace.RGB], num_frames=["random"], dtypes=[torch.float32]
+    ):
+        yield ArgsKwargs(video_loader, mean=mean, std=std)
+
+
+KERNEL_INFOS.extend(
+    [
+        KernelInfo(
+            F.normalize_image_tensor,
+            kernel_name="normalize_image_tensor",
+            sample_inputs_fn=sample_inputs_normalize_image_tensor,
+        ),
+        KernelInfo(
+            F.normalize_video,
+            sample_inputs_fn=sample_inputs_normalize_video,
+        ),
+    ]
)
@@ -17,6 +17,7 @@ from prototype_common_utils import (
    make_masks,
    make_one_hot_labels,
    make_segmentation_mask,
+    make_videos,
)
from torchvision.ops.boxes import box_iou
from torchvision.prototype import features, transforms
@@ -65,6 +66,7 @@ def parametrize_from_transforms(*transforms):
            make_vanilla_tensor_images,
            make_pil_images,
            make_masks,
+            make_videos,
        ]:
            inputs = list(creation_fn())
            try:
@@ -155,12 +157,14 @@ class TestSmoke:
                    features.ColorSpace.RGB,
                ],
                dtypes=[torch.uint8],
-                extra_dims=[(4,)],
+                extra_dims=[(), (4,)],
+                **(dict(num_frames=["random"]) if fn is make_videos else dict()),
            )
            for fn in [
                make_images,
                make_vanilla_tensor_images,
                make_pil_images,
+                make_videos,
            ]
        ),
    )
@@ -184,6 +188,7 @@ class TestSmoke:
            for fn in [
                make_images,
                make_vanilla_tensor_images,
+                make_videos,
            ]
        ),
    ),
@@ -200,6 +205,7 @@ class TestSmoke:
                make_images(extra_dims=[(4,)]),
                make_vanilla_tensor_images(),
                make_pil_images(),
+                make_videos(extra_dims=[()]),
            ),
        )
    ]
@@ -218,6 +224,7 @@ class TestSmoke:
                make_images,
                make_vanilla_tensor_images,
                make_pil_images,
+                make_videos,
            )
        ]
    ),
@@ -129,6 +129,7 @@ class TestKernels:
            # type all kernels should also work without differentiating between the two. Thus, we go with 2 here as
            # common ground.
            features.Mask: 2,
+            features.Video: 4,
        }.get(feature_type)
        if data_dims is None:
            raise pytest.UsageError(
@@ -13,3 +13,4 @@ from ._image import (
)
from ._label import Label, OneHotLabel
from ._mask import Mask
+from ._video import ImageOrVideoType, ImageOrVideoTypeJIT, TensorImageOrVideoType, TensorImageOrVideoTypeJIT, Video
from __future__ import annotations
import warnings
from typing import Any, cast, List, Optional, Tuple, Union
import torch
from torchvision.transforms.functional import InterpolationMode
from ._feature import _Feature, FillTypeJIT
from ._image import ColorSpace, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
class Video(_Feature):
color_space: ColorSpace
def __new__(
cls,
data: Any,
*,
color_space: Optional[Union[ColorSpace, str]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
) -> Video:
data = torch.as_tensor(data, dtype=dtype, device=device)
if data.ndim < 4:
raise ValueError
video = super().__new__(cls, data, requires_grad=requires_grad)
if color_space is None:
color_space = ColorSpace.from_tensor_shape(video.shape) # type: ignore[arg-type]
if color_space == ColorSpace.OTHER:
warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
elif isinstance(color_space, str):
color_space = ColorSpace.from_str(color_space.upper())
elif not isinstance(color_space, ColorSpace):
raise ValueError
video.color_space = color_space
return video
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
return self._make_repr(color_space=self.color_space)
@classmethod
def new_like(
cls, other: Video, data: Any, *, color_space: Optional[Union[ColorSpace, str]] = None, **kwargs: Any
) -> Video:
return super().new_like(
other, data, color_space=color_space if color_space is not None else other.color_space, **kwargs
)
# TODO: rename this (and all instances of this term to spatial size)
@property
def image_size(self) -> Tuple[int, int]:
return cast(Tuple[int, int], tuple(self.shape[-2:]))
@property
def num_channels(self) -> int:
return self.shape[-3]
@property
def num_frames(self) -> int:
return self.shape[-4]
def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) -> Video:
if isinstance(color_space, str):
color_space = ColorSpace.from_str(color_space.upper())
return Video.new_like(
self,
self._F.convert_color_space_video(
self, old_color_space=self.color_space, new_color_space=color_space, copy=copy
),
color_space=color_space,
)
def horizontal_flip(self) -> Video:
output = self._F.horizontal_flip_video(self)
return Video.new_like(self, output)
def vertical_flip(self) -> Video:
output = self._F.vertical_flip_video(self)
return Video.new_like(self, output)
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: bool = False,
) -> Video:
output = self._F.resize_video(self, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
return Video.new_like(self, output)
def crop(self, top: int, left: int, height: int, width: int) -> Video:
output = self._F.crop_video(self, top, left, height, width)
return Video.new_like(self, output)
def center_crop(self, output_size: List[int]) -> Video:
output = self._F.center_crop_video(self, output_size=output_size)
return Video.new_like(self, output)
def resized_crop(
self,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: bool = False,
) -> Video:
output = self._F.resized_crop_video(
self, top, left, height, width, size=list(size), interpolation=interpolation, antialias=antialias
)
return Video.new_like(self, output)
def pad(
self,
padding: Union[int, List[int]],
fill: FillTypeJIT = None,
padding_mode: str = "constant",
) -> Video:
output = self._F.pad_video(self, padding, fill=fill, padding_mode=padding_mode)
return Video.new_like(self, output)
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Video:
output = self._F._geometry.rotate_video(
self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center
)
return Video.new_like(self, output)
def affine(
self,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Video:
output = self._F._geometry.affine_video(
self,
angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
return Video.new_like(self, output)
def perspective(
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> Video:
output = self._F._geometry.perspective_video(self, perspective_coeffs, interpolation=interpolation, fill=fill)
return Video.new_like(self, output)
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> Video:
output = self._F._geometry.elastic_video(self, displacement, interpolation=interpolation, fill=fill)
return Video.new_like(self, output)
def adjust_brightness(self, brightness_factor: float) -> Video:
output = self._F.adjust_brightness_video(self, brightness_factor=brightness_factor)
return Video.new_like(self, output)
def adjust_saturation(self, saturation_factor: float) -> Video:
output = self._F.adjust_saturation_video(self, saturation_factor=saturation_factor)
return Video.new_like(self, output)
def adjust_contrast(self, contrast_factor: float) -> Video:
output = self._F.adjust_contrast_video(self, contrast_factor=contrast_factor)
return Video.new_like(self, output)
def adjust_sharpness(self, sharpness_factor: float) -> Video:
output = self._F.adjust_sharpness_video(self, sharpness_factor=sharpness_factor)
return Video.new_like(self, output)
def adjust_hue(self, hue_factor: float) -> Video:
output = self._F.adjust_hue_video(self, hue_factor=hue_factor)
return Video.new_like(self, output)
def adjust_gamma(self, gamma: float, gain: float = 1) -> Video:
output = self._F.adjust_gamma_video(self, gamma=gamma, gain=gain)
return Video.new_like(self, output)
def posterize(self, bits: int) -> Video:
output = self._F.posterize_video(self, bits=bits)
return Video.new_like(self, output)
def solarize(self, threshold: float) -> Video:
output = self._F.solarize_video(self, threshold=threshold)
return Video.new_like(self, output)
def autocontrast(self) -> Video:
output = self._F.autocontrast_video(self)
return Video.new_like(self, output)
def equalize(self) -> Video:
output = self._F.equalize_video(self)
return Video.new_like(self, output)
def invert(self) -> Video:
output = self._F.invert_video(self)
return Video.new_like(self, output)
def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Video:
output = self._F.gaussian_blur_video(self, kernel_size=kernel_size, sigma=sigma)
return Video.new_like(self, output)
VideoType = Union[torch.Tensor, Video]
VideoTypeJIT = torch.Tensor
LegacyVideoType = torch.Tensor
LegacyVideoTypeJIT = torch.Tensor
TensorVideoType = Union[torch.Tensor, Video]
TensorVideoTypeJIT = torch.Tensor
ImageOrVideoType = Union[ImageType, VideoType]
ImageOrVideoTypeJIT = Union[ImageTypeJIT, VideoTypeJIT]
TensorImageOrVideoType = Union[TensorImageType, TensorVideoType]
TensorImageOrVideoTypeJIT = Union[TensorImageTypeJIT, TensorVideoTypeJIT]
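
A hedged usage sketch of the new feature type (prototype API, subject to change; the tensor values are arbitrary examples):

import torch
from torchvision.prototype import features

video = features.Video(torch.rand(8, 3, 32, 32))  # (T, C, H, W); RGB is guessed from 3 channels
assert (video.num_frames, video.num_channels, video.image_size) == (8, 3, (32, 32))

flipped = video.horizontal_flip()  # delegates to F.horizontal_flip_video
assert isinstance(flipped, features.Video) and flipped.color_space == video.color_space
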
@@ -15,7 +15,7 @@ from ._utils import has_any, query_chw
class RandomErasing(_RandomApplyTransform):
-    _transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image)
+    _transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image, features.Video)

    def __init__(
        self,
@@ -92,7 +92,7 @@ class RandomErasing(_RandomApplyTransform):
        return dict(i=i, j=j, h=h, w=w, v=v)

-    def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
+    def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
        if params["v"] is not None:
            inpt = F.erase(inpt, **params, inplace=self.inplace)
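
With features.Video registered in _transformed_types, the transform should also apply to video inputs. A hedged sketch of the expected behavior (prototype API; p=1.0 forces the random apply):

import torch
from torchvision.prototype import features, transforms

video = features.Video(torch.rand(4, 3, 32, 32))
out = transforms.RandomErasing(p=1.0)(video)  # dispatches to F.erase on the video
assert isinstance(out, features.Video) and out.shape == video.shape
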
@@ -31,40 +31,41 @@ class _AutoAugmentBase(Transform):
        key = keys[int(torch.randint(len(keys), ()))]
        return key, dct[key]

-    def _extract_image(
+    def _extract_image_or_video(
        self,
        sample: Any,
        unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.Mask),
-    ) -> Tuple[int, features.ImageType]:
+    ) -> Tuple[int, features.ImageOrVideoType]:
        sample_flat, _ = tree_flatten(sample)
-        images = []
+        image_or_videos = []
        for id, inpt in enumerate(sample_flat):
-            if _isinstance(inpt, (features.Image, PIL.Image.Image, features.is_simple_tensor)):
-                images.append((id, inpt))
+            if _isinstance(inpt, (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)):
+                image_or_videos.append((id, inpt))
            elif isinstance(inpt, unsupported_types):
                raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")

-        if not images:
+        if not image_or_videos:
            raise TypeError("Found no image in the sample.")
-        if len(images) > 1:
+        if len(image_or_videos) > 1:
            raise TypeError(
-                f"Auto augment transformations are only properly defined for a single image, but found {len(images)}."
+                f"Auto augment transformations are only properly defined for a single image or video, "
+                f"but found {len(image_or_videos)}."
            )

-        return images[0]
+        return image_or_videos[0]

    def _put_into_sample(self, sample: Any, id: int, item: Any) -> Any:
        sample_flat, spec = tree_flatten(sample)
        sample_flat[id] = item
        return tree_unflatten(sample_flat, spec)

-    def _apply_image_transform(
+    def _apply_image_or_video_transform(
        self,
-        image: features.ImageType,
+        image: features.ImageOrVideoType,
        transform_id: str,
        magnitude: float,
        interpolation: InterpolationMode,
        fill: Dict[Type, features.FillType],
-    ) -> features.ImageType:
+    ) -> features.ImageOrVideoType:
        fill_ = fill[type(image)]
        fill_ = F._geometry._convert_fill_arg(fill_)
@@ -276,8 +277,8 @@ class AutoAugment(_AutoAugmentBase):
    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]

-        id, image = self._extract_image(sample)
-        _, height, width = get_chw(image)
+        id, image_or_video = self._extract_image_or_video(sample)
+        _, height, width = get_chw(image_or_video)

        policy = self._policies[int(torch.randint(len(self._policies), ()))]
@@ -295,11 +296,11 @@ class AutoAugment(_AutoAugmentBase):
            else:
                magnitude = 0.0

-            image = self._apply_image_transform(
-                image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
+            image_or_video = self._apply_image_or_video_transform(
+                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
            )

-        return self._put_into_sample(sample, id, image)
+        return self._put_into_sample(sample, id, image_or_video)
class RandAugment(_AutoAugmentBase):
@@ -347,8 +348,8 @@ class RandAugment(_AutoAugmentBase):
    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]

-        id, image = self._extract_image(sample)
-        _, height, width = get_chw(image)
+        id, image_or_video = self._extract_image_or_video(sample)
+        _, height, width = get_chw(image_or_video)

        for _ in range(self.num_ops):
            transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
@@ -359,11 +360,11 @@ class RandAugment(_AutoAugmentBase):
                magnitude *= -1
            else:
                magnitude = 0.0

-            image = self._apply_image_transform(
-                image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
+            image_or_video = self._apply_image_or_video_transform(
+                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
            )

-        return self._put_into_sample(sample, id, image)
+        return self._put_into_sample(sample, id, image_or_video)
class TrivialAugmentWide(_AutoAugmentBase):
@@ -401,8 +402,8 @@ class TrivialAugmentWide(_AutoAugmentBase):
    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]

        id, image_or_video = self._extract_image_or_video(sample)
        _, height, width = get_chw(image_or_video)

        transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
@@ -414,10 +415,10 @@ class TrivialAugmentWide(_AutoAugmentBase):
        else:
            magnitude = 0.0

        image_or_video = self._apply_image_or_video_transform(
            image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
        )

        return self._put_into_sample(sample, id, image_or_video)
class AugMix(_AutoAugmentBase):
@@ -471,27 +472,28 @@ class AugMix(_AutoAugmentBase):
    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        id, orig_image_or_video = self._extract_image_or_video(sample)
        _, height, width = get_chw(orig_image_or_video)

        if isinstance(orig_image_or_video, torch.Tensor):
            image_or_video = orig_image_or_video
        else:  # isinstance(inpt, PIL.Image.Image):
            image_or_video = F.pil_to_tensor(orig_image_or_video)

        augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE

        orig_dims = list(image_or_video.shape)
        batch = image_or_video.view([1] * max(4 - image_or_video.ndim, 0) + orig_dims)
        batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

        # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
        # Dirichlet with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of
        # augmented image or video.
        m = self._sample_dirichlet(
            torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
        )

        # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos.
        combined_weights = self._sample_dirichlet(
            torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
        ) * m[:, 1].view([batch_dims[0], -1])
@@ -511,15 +513,15 @@ class AugMix(_AutoAugmentBase):
                else:
                    magnitude = 0.0

                aug = self._apply_image_or_video_transform(
                    aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
                )
            mix.add_(combined_weights[:, i].view(batch_dims) * aug)
        mix = mix.view(orig_dims).to(dtype=image_or_video.dtype)

        if isinstance(orig_image_or_video, (features.Image, features.Video)):
            mix = type(orig_image_or_video).new_like(orig_image_or_video, mix)  # type: ignore[arg-type]
        elif isinstance(orig_image_or_video, PIL.Image.Image):
            mix = F.to_image_pil(mix)

        return self._put_into_sample(sample, id, mix)
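The Beta-from-Dirichlet trick in the comment above is standard: a two-parameter Dirichlet(alpha, alpha) sample is a pair (w, 1 - w) whose first coordinate is Beta(alpha, alpha) distributed. A small sketch of the sampling shapes, using torch.distributions directly instead of the transform's _sample_dirichlet helper:

import torch

alpha, batch_size, mixture_width = 1.0, 4, 3

# One (w, 1 - w) pair per batch element; column 0 weights the original input.
m = torch.distributions.Dirichlet(torch.tensor([alpha, alpha])).sample((batch_size,))

# Mixing weights for the augmentation chains, scaled by the Beta weight of the augmented part.
combined = torch.distributions.Dirichlet(
    torch.full((mixture_width,), alpha)
).sample((batch_size,)) * m[:, 1:2]

assert torch.allclose(m.sum(dim=1), torch.ones(batch_size))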
@@ -82,7 +82,7 @@ class ColorJitter(Transform):
class RandomPhotometricDistort(Transform):
    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)

    def __init__(
        self,
@@ -110,20 +110,22 @@ class RandomPhotometricDistort(Transform):
            channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None,
        )
    def _permute_channels(
        self, inpt: features.ImageOrVideoType, permutation: torch.Tensor
    ) -> features.ImageOrVideoType:
        if isinstance(inpt, PIL.Image.Image):
            inpt = F.pil_to_tensor(inpt)

        output = inpt[..., permutation, :, :]

        if isinstance(inpt, (features.Image, features.Video)):
            output = type(inpt).new_like(inpt, output, color_space=features.ColorSpace.OTHER)  # type: ignore[arg-type]
        elif isinstance(inpt, PIL.Image.Image):
            output = F.to_image_pil(output)

        return output
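Indexing with an ellipsis is what lets `_permute_channels` serve both images and videos with one expression: `[..., permutation, :, :]` always targets the channel axis third from the end, whatever leading batch or time dimensions exist. A quick illustration with made-up shapes:

import torch

perm = torch.tensor([2, 0, 1])

image = torch.rand(3, 8, 8)      # C, H, W
video = torch.rand(16, 3, 8, 8)  # T, C, H, W

assert image[..., perm, :, :].shape == (3, 8, 8)
assert video[..., perm, :, :].shape == (16, 3, 8, 8)
assert torch.equal(video[..., perm, :, :][0], video[0, perm])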
    def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
        if params["brightness"]:
            inpt = F.adjust_brightness(
                inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1])
...
@@ -855,8 +855,10 @@ class FixedSizeCrop(Transform):
        return inpt

    def forward(self, *inputs: Any) -> Any:
        if not has_any(inputs, PIL.Image.Image, features.Image, features.is_simple_tensor, features.Video):
            raise TypeError(
                f"{type(self).__name__}() requires input sample to contain a tensor or PIL image or a Video."
            )

        if has_any(inputs, features.BoundingBox) and not has_any(inputs, features.Label, features.OneHotLabel):
            raise TypeError(
@@ -34,7 +34,7 @@ class ConvertImageDtype(Transform):
class ConvertColorSpace(Transform):
    _transformed_types = (features.is_simple_tensor, features.Image, PIL.Image.Image, features.Video)

    def __init__(
        self,
@@ -54,7 +54,7 @@ class ConvertColorSpace(Transform):
        self.copy = copy

    def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
        return F.convert_color_space(
            inpt, color_space=self.color_space, old_color_space=self.old_color_space, copy=self.copy
        )
...
@@ -38,7 +38,7 @@ class Lambda(Transform):
class LinearTransformation(Transform):
    _transformed_types = (features.is_simple_tensor, features.Image, features.Video)

    def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
        super().__init__()
@@ -68,7 +68,7 @@ class LinearTransformation(Transform):
        return super().forward(*inputs)

    def _transform(self, inpt: features.TensorImageOrVideoType, params: Dict[str, Any]) -> torch.Tensor:
        # An Image instance after a linear transformation is no longer an Image, since the data range is unknown.
        # Thus we return a plain Tensor for Image input.
@@ -93,7 +93,7 @@ class LinearTransformation(Transform):
class Normalize(Transform):
    _transformed_types = (features.Image, features.is_simple_tensor, features.Video)

    def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
        super().__init__()
@@ -101,7 +101,7 @@ class Normalize(Transform):
        self.std = list(std)
        self.inplace = inplace

    def _transform(self, inpt: features.TensorImageOrVideoType, params: Dict[str, Any]) -> torch.Tensor:
        return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)

    def forward(self, *inpts: Any) -> Any:
...
@@ -82,10 +82,10 @@ def query_chw(sample: Any) -> Tuple[int, int, int]:
    chws = {
        get_chw(item)
        for item in flat_sample
        if isinstance(item, (features.Image, PIL.Image.Image, features.Video)) or features.is_simple_tensor(item)
    }
    if not chws:
        raise TypeError("No image or video was found in the sample")
    elif len(chws) > 1:
        raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
    return chws.pop()
...
@@ -6,6 +6,7 @@ from ._meta import (
    convert_format_bounding_box,
    convert_color_space_image_tensor,
    convert_color_space_image_pil,
    convert_color_space_video,
    convert_color_space,
    get_dimensions,
    get_image_num_channels,
@@ -13,41 +14,52 @@ from ._meta import (
    get_spatial_size,
)  # usort: skip

from ._augment import erase, erase_image_pil, erase_image_tensor, erase_video
from ._color import (
    adjust_brightness,
    adjust_brightness_image_pil,
    adjust_brightness_image_tensor,
    adjust_brightness_video,
    adjust_contrast,
    adjust_contrast_image_pil,
    adjust_contrast_image_tensor,
    adjust_contrast_video,
    adjust_gamma,
    adjust_gamma_image_pil,
    adjust_gamma_image_tensor,
    adjust_gamma_video,
    adjust_hue,
    adjust_hue_image_pil,
    adjust_hue_image_tensor,
    adjust_hue_video,
    adjust_saturation,
    adjust_saturation_image_pil,
    adjust_saturation_image_tensor,
    adjust_saturation_video,
    adjust_sharpness,
    adjust_sharpness_image_pil,
    adjust_sharpness_image_tensor,
    adjust_sharpness_video,
    autocontrast,
    autocontrast_image_pil,
    autocontrast_image_tensor,
    autocontrast_video,
    equalize,
    equalize_image_pil,
    equalize_image_tensor,
    equalize_video,
    invert,
    invert_image_pil,
    invert_image_tensor,
    invert_video,
    posterize,
    posterize_image_pil,
    posterize_image_tensor,
    posterize_video,
    solarize,
    solarize_image_pil,
    solarize_image_tensor,
    solarize_video,
)
from ._geometry import (
    affine,
@@ -55,22 +67,26 @@ from ._geometry import (
    affine_image_pil,
    affine_image_tensor,
    affine_mask,
    affine_video,
    center_crop,
    center_crop_bounding_box,
    center_crop_image_pil,
    center_crop_image_tensor,
    center_crop_mask,
    center_crop_video,
    crop,
    crop_bounding_box,
    crop_image_pil,
    crop_image_tensor,
    crop_mask,
    crop_video,
    elastic,
    elastic_bounding_box,
    elastic_image_pil,
    elastic_image_tensor,
    elastic_mask,
    elastic_transform,
    elastic_video,
    five_crop,
    five_crop_image_pil,
    five_crop_image_tensor,
@@ -80,31 +96,37 @@ from ._geometry import (
    horizontal_flip_image_pil,
    horizontal_flip_image_tensor,
    horizontal_flip_mask,
    horizontal_flip_video,
    pad,
    pad_bounding_box,
    pad_image_pil,
    pad_image_tensor,
    pad_mask,
    pad_video,
    perspective,
    perspective_bounding_box,
    perspective_image_pil,
    perspective_image_tensor,
    perspective_mask,
    perspective_video,
    resize,
    resize_bounding_box,
    resize_image_pil,
    resize_image_tensor,
    resize_mask,
    resize_video,
    resized_crop,
    resized_crop_bounding_box,
    resized_crop_image_pil,
    resized_crop_image_tensor,
    resized_crop_mask,
    resized_crop_video,
    rotate,
    rotate_bounding_box,
    rotate_image_pil,
    rotate_image_tensor,
    rotate_mask,
    rotate_video,
    ten_crop,
    ten_crop_image_pil,
    ten_crop_image_tensor,
@@ -113,9 +135,18 @@ from ._geometry import (
    vertical_flip_image_pil,
    vertical_flip_image_tensor,
    vertical_flip_mask,
    vertical_flip_video,
    vflip,
)
from ._misc import (
    gaussian_blur,
    gaussian_blur_image_pil,
    gaussian_blur_image_tensor,
    gaussian_blur_video,
    normalize,
    normalize_image_tensor,
    normalize_video,
)
from ._type_conversion import (
    convert_image_dtype,
    decode_image_with_pil,
...
@@ -17,19 +17,25 @@ def erase_image_pil(
    return to_pil_image(output, mode=image.mode)


def erase_video(
    video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
    return erase_image_tensor(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace)


def erase(
    inpt: features.ImageOrVideoTypeJIT,
    i: int,
    j: int,
    h: int,
    w: int,
    v: torch.Tensor,
    inplace: bool = False,
) -> features.ImageOrVideoTypeJIT:
    if isinstance(inpt, torch.Tensor):
        output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
        if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
            output = type(inpt).new_like(inpt, output)  # type: ignore[arg-type]
        return output
    else:  # isinstance(inpt, PIL.Image.Image):
        return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
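The `erase` dispatcher shows the pattern every functional op in this commit follows: tensors (including tensor subclasses such as `features.Video`) route to the tensor kernel, the feature type is restored via `new_like`, and PIL images take the PIL path. A stripped-down sketch of the same dispatch shape, with hypothetical kernel names:

import torch
import PIL.Image

def my_op_image_tensor(t: torch.Tensor) -> torch.Tensor:  # hypothetical kernel
    return t.clone()

def my_op_image_pil(img: PIL.Image.Image) -> PIL.Image.Image:  # hypothetical kernel
    return img.copy()

def my_op(inpt):
    if isinstance(inpt, torch.Tensor):
        # Covers plain tensors and tensor subclasses such as Image/Video features.
        return my_op_image_tensor(inpt)
    else:  # assume PIL.Image.Image
        return my_op_image_pil(inpt)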
@@ -2,10 +2,16 @@ import torch
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

from ._meta import get_dimensions_image_tensor

adjust_brightness_image_tensor = _FT.adjust_brightness
adjust_brightness_image_pil = _FP.adjust_brightness


def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor:
    return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor)
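Most of the new `*_video` kernels are one-line aliases for the matching image-tensor kernel: those kernels operate on the trailing (C, H, W) dimensions and broadcast over any leading batch dimensions, so a (T, C, H, W) video is just a batch of frames. A sketch of why that works, using a hand-written brightness adjustment rather than the torchvision kernel:

import torch

def brightness_like(t: torch.Tensor, factor: float) -> torch.Tensor:
    # Elementwise op: indifferent to how many leading dims precede (C, H, W).
    return (t * factor).clamp_(0.0, 1.0)

video = torch.rand(10, 3, 4, 4)  # T, C, H, W

out = brightness_like(video, 1.5)
# Applying the kernel frame by frame gives the same result.
assert torch.allclose(out, torch.stack([brightness_like(f, 1.5) for f in video]))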
def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) -> features.InputTypeJIT:
    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
        return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
@@ -19,6 +25,10 @@ adjust_saturation_image_tensor = _FT.adjust_saturation
adjust_saturation_image_pil = _FP.adjust_saturation


def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor:
    return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor)


def adjust_saturation(inpt: features.InputTypeJIT, saturation_factor: float) -> features.InputTypeJIT:
    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
        return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
@@ -32,6 +42,10 @@ adjust_contrast_image_tensor = _FT.adjust_contrast
adjust_contrast_image_pil = _FP.adjust_contrast


def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor:
    return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor)


def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> features.InputTypeJIT:
    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
        return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
@@ -41,10 +55,40 @@ def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> features.InputTypeJIT:
    return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
    num_channels, height, width = get_dimensions_image_tensor(image)
    if num_channels not in (1, 3):
        raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}")

    if sharpness_factor < 0:
        raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.")

    if image.numel() == 0 or height <= 2 or width <= 2:
        return image

    shape = image.shape

    if image.ndim > 4:
        image = image.view(-1, num_channels, height, width)
        needs_unsquash = True
    else:
        needs_unsquash = False

    output = _FT._blend(image, _FT._blurred_degenerate_image(image), sharpness_factor)

    if needs_unsquash:
        output = output.view(shape)

    return output
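This squash-and-restore dance (flatten all leading dims into one batch dim, run the 4D kernel, then view back) recurs throughout the commit for kernels whose internals assume at most one batch dimension. A generic sketch of the pattern as a standalone helper, hypothetical and not part of the diff:

import torch
from typing import Callable

def apply_batched(kernel: Callable[[torch.Tensor], torch.Tensor], t: torch.Tensor) -> torch.Tensor:
    """Run a kernel that expects (N, C, H, W) on a tensor with arbitrary leading dims."""
    shape = t.shape
    if t.ndim > 4:
        t = t.view((-1,) + shape[-3:])  # e.g. (B, T, C, H, W) -> (B*T, C, H, W)
        out = kernel(t)
        return out.view(shape)          # restore the original leading dims
    return kernel(t)

video_batch = torch.rand(2, 5, 3, 8, 8)  # B, T, C, H, W
out = apply_batched(lambda x: x.flip(-1), video_batch)  # stand-in kernel
assert out.shape == video_batch.shape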
adjust_sharpness_image_pil = _FP.adjust_sharpness


def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
    return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor)


def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> features.InputTypeJIT:
    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
        return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
@@ -58,6 +102,10 @@ adjust_hue_image_tensor = _FT.adjust_hue
adjust_hue_image_pil = _FP.adjust_hue


def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
    return adjust_hue_image_tensor(video, hue_factor=hue_factor)


def adjust_hue(inpt: features.InputTypeJIT, hue_factor: float) -> features.InputTypeJIT:
    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
        return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
@@ -71,6 +119,10 @@ adjust_gamma_image_tensor = _FT.adjust_gamma
adjust_gamma_image_pil = _FP.adjust_gamma


def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
    return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain)


def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) -> features.InputTypeJIT:
    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
        return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
@@ -84,6 +136,10 @@ posterize_image_tensor = _FT.posterize
posterize_image_pil = _FP.posterize


def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
    return posterize_image_tensor(video, bits=bits)


def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
        return posterize_image_tensor(inpt, bits=bits)
@@ -97,6 +153,10 @@ solarize_image_tensor = _FT.solarize
solarize_image_pil = _FP.solarize


def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
    return solarize_image_tensor(video, threshold=threshold)


def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTypeJIT:
    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
        return solarize_image_tensor(inpt, threshold=threshold)
@@ -110,6 +170,10 @@ autocontrast_image_tensor = _FT.autocontrast
autocontrast_image_pil = _FP.autocontrast


def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
    return autocontrast_image_tensor(video)


def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
        return autocontrast_image_tensor(inpt)
@@ -119,10 +183,35 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
    return autocontrast_image_pil(inpt)
def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
    if image.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}")

    num_channels, height, width = get_dimensions_image_tensor(image)
    if num_channels not in (1, 3):
        raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}")

    if image.numel() == 0:
        return image
    elif image.ndim == 2:
        return _FT._scale_channel(image)
    else:
        return torch.stack(
            [
                # TODO: when merging transforms v1 and v2, we can inline this function call
                _FT._equalize_single_image(single_image)
                for single_image in image.view(-1, num_channels, height, width)
            ]
        ).view(image.shape)


equalize_image_pil = _FP.equalize


def equalize_video(video: torch.Tensor) -> torch.Tensor:
    return equalize_image_tensor(video)
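Because `equalize_image_tensor` flattens every leading dimension before iterating, the per-frame histogram equalization extends to videos for free. A small check of that behavior, assuming the prototype functional namespace is importable as below:

import torch
from torchvision.prototype.transforms import functional as F

video = torch.randint(0, 256, (4, 3, 16, 16), dtype=torch.uint8)  # T, C, H, W

out = F.equalize_video(video)
assert out.shape == video.shape and out.dtype == torch.uint8

# Each frame matches equalizing it on its own.
assert torch.equal(out[0], F.equalize_image_tensor(video[0]))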
def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
        return equalize_image_tensor(inpt)
@@ -136,6 +225,10 @@ invert_image_tensor = _FT.invert
invert_image_pil = _FP.invert


def invert_video(video: torch.Tensor) -> torch.Tensor:
    return invert_image_tensor(video)


def invert(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
        return invert_image_tensor(inpt)
...
@@ -47,6 +47,10 @@ def horizontal_flip_bounding_box(
    ).view(shape)


def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
    return horizontal_flip_image_tensor(video)


def horizontal_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
        return horizontal_flip_image_tensor(inpt)
@@ -80,6 +84,10 @@ def vertical_flip_bounding_box(
    ).view(shape)


def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
    return vertical_flip_image_tensor(video)


def vertical_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
        return vertical_flip_image_tensor(inpt)
@@ -185,6 +193,16 @@ def resize_bounding_box(
    )


def resize_video(
    video: torch.Tensor,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: bool = False,
) -> torch.Tensor:
    return resize_image_tensor(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
def resize(
    inpt: features.InputTypeJIT,
    size: List[int],
@@ -441,6 +459,28 @@ def affine_mask(
    return output


def affine_video(
    video: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    fill: features.FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    return affine_image_tensor(
        video,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=interpolation,
        fill=fill,
        center=center,
    )


def _convert_fill_arg(fill: features.FillType) -> features.FillTypeJIT:
    # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
    # So, we can't reassign fill to 0
@@ -614,6 +654,17 @@ def rotate_mask(
    return output


def rotate_video(
    video: torch.Tensor,
    angle: float,
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    expand: bool = False,
    fill: features.FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    return rotate_image_tensor(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)


def rotate(
    inpt: features.InputTypeJIT,
    angle: float,
@@ -751,6 +802,15 @@ def pad_bounding_box(
    return bounding_box, (height, width)


def pad_video(
    video: torch.Tensor,
    padding: Union[int, List[int]],
    fill: features.FillTypeJIT = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    return pad_image_tensor(video, padding, fill=fill, padding_mode=padding_mode)


def pad(
    inpt: features.InputTypeJIT,
    padding: Union[int, List[int]],
@@ -798,6 +858,10 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int)
    return crop_image_tensor(mask, top, left, height, width)


def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    return crop_image_tensor(video, top, left, height, width)


def crop(inpt: features.InputTypeJIT, top: int, left: int, height: int, width: int) -> features.InputTypeJIT:
    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
        return crop_image_tensor(inpt, top, left, height, width)
@@ -932,6 +996,33 @@ def perspective_mask(
    return output
def perspective_video(
    video: torch.Tensor,
    perspective_coeffs: List[float],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    fill: features.FillTypeJIT = None,
) -> torch.Tensor:
    # TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when
    # https://github.com/pytorch/vision/issues/6670 is resolved.
    if video.numel() == 0:
        return video

    shape = video.shape

    if video.ndim > 4:
        video = video.view((-1,) + shape[-3:])
        needs_unsquash = True
    else:
        needs_unsquash = False

    output = perspective_image_tensor(video, perspective_coeffs, interpolation=interpolation, fill=fill)

    if needs_unsquash:
        output = output.view(shape)

    return output
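With the workaround in place, `perspective_video` accepts inputs with any number of leading dimensions even though the underlying image kernel only handles up to one batch dimension. A quick shape check, assuming the prototype functional namespace; the identity-like coefficient values are standard for torchvision's 8-coefficient perspective parametrization:

import torch
from torchvision.prototype.transforms import functional as F

coeffs = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]  # identity homography coefficients

clips = torch.rand(2, 5, 3, 16, 16)  # B, T, C, H, W
out = F.perspective_video(clips, coeffs)
assert out.shape == clips.shape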
def perspective(
    inpt: features.InputTypeJIT,
    perspective_coeffs: List[float],
@@ -1026,6 +1117,33 @@ def elastic_mask(
    return output


def elastic_video(
    video: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    fill: features.FillTypeJIT = None,
) -> torch.Tensor:
    # TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when
    # https://github.com/pytorch/vision/issues/6670 is resolved.
    if video.numel() == 0:
        return video

    shape = video.shape

    if video.ndim > 4:
        video = video.view((-1,) + shape[-3:])
        needs_unsquash = True
    else:
        needs_unsquash = False

    output = elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)

    if needs_unsquash:
        output = output.view(shape)

    return output
def elastic(
    inpt: features.InputTypeJIT,
    displacement: torch.Tensor,
@@ -1128,6 +1246,10 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor
    return output


def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    return center_crop_image_tensor(video, output_size)


def center_crop(inpt: features.InputTypeJIT, output_size: List[int]) -> features.InputTypeJIT:
    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
        return center_crop_image_tensor(inpt, output_size)
@@ -1190,6 +1312,21 @@ def resized_crop_mask(
    return resize_mask(mask, size)


def resized_crop_video(
    video: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    antialias: bool = False,
) -> torch.Tensor:
    return resized_crop_image_tensor(
        video, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
    )


def resized_crop(
    inpt: features.InputTypeJIT,
    top: int,
...
@@ -11,10 +11,12 @@ get_dimensions_image_pil = _FP.get_dimensions

# TODO: Should this be prefixed with `_` similar to other methods that don't get exposed by init?
def get_chw(image: features.ImageOrVideoTypeJIT) -> Tuple[int, int, int]:
    if isinstance(image, torch.Tensor) and (
        torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
    ):
        channels, height, width = get_dimensions_image_tensor(image)
    elif isinstance(image, (features.Image, features.Video)):
        channels = image.num_channels
        height, width = image.image_size
    else:  # isinstance(image, PIL.Image.Image)
@@ -29,11 +31,11 @@ def get_chw(image: features.ImageOrVideoTypeJIT) -> Tuple[int, int, int]:
# detailed above.


def get_dimensions(image: features.ImageOrVideoTypeJIT) -> List[int]:
    return list(get_chw(image))


def get_num_channels(image: features.ImageOrVideoTypeJIT) -> int:
    num_channels, *_ = get_chw(image)
    return num_channels
@@ -43,7 +45,7 @@ def get_num_channels(image: features.ImageOrVideoTypeJIT) -> int:
get_image_num_channels = get_num_channels


def get_spatial_size(image: features.ImageOrVideoTypeJIT) -> List[int]:
    _, *size = get_chw(image)
    return size
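`get_chw` is the single shape oracle the transforms above rely on: for `Image` and `Video` features it reads cached metadata, for raw tensors it inspects the trailing dims. A usage sketch, assuming the `Video` feature added in this commit accepts a (B, T, C, H, W) tensor and exposes `num_channels` and `image_size` metadata:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms.functional import get_dimensions, get_spatial_size

video = features.Video(torch.rand(2, 8, 3, 24, 32))  # assumed constructor signature

assert get_dimensions(video) == [3, 24, 32]  # channels come from metadata, not ndim
assert get_spatial_size(video) == [24, 32]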
@@ -207,13 +209,23 @@ def convert_color_space_image_pil(
    return image.convert(new_mode)


def convert_color_space_video(
    video: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace, copy: bool = True
) -> torch.Tensor:
    return convert_color_space_image_tensor(
        video, old_color_space=old_color_space, new_color_space=new_color_space, copy=copy
    )


def convert_color_space(
    inpt: features.ImageOrVideoTypeJIT,
    color_space: ColorSpace,
    old_color_space: Optional[ColorSpace] = None,
    copy: bool = True,
) -> features.ImageOrVideoTypeJIT:
    if isinstance(inpt, torch.Tensor) and (
        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
    ):
        if old_color_space is None:
            raise RuntimeError(
                "In order to convert the color space of simple tensor images, "
@@ -222,7 +234,7 @@ def convert_color_space(
        return convert_color_space_image_tensor(
            inpt, old_color_space=old_color_space, new_color_space=color_space, copy=copy
        )
    elif isinstance(inpt, (features.Image, features.Video)):
        return inpt.to_color_space(color_space, copy=copy)
    else:
        return cast(features.ImageOrVideoTypeJIT, convert_color_space_image_pil(inpt, color_space, copy=copy))
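Note the asymmetry in `convert_color_space`: a plain tensor carries no color-space metadata, so the caller must pass `old_color_space`, whereas `Image`/`Video` features convert via their own `to_color_space` method. A sketch of both call styles; the `color_space` keyword on the `Video` constructor is an assumption based on the `Image` feature:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

rgb = torch.rand(8, 3, 16, 16)  # simple video tensor, no metadata attached

# Simple tensors need the source color space spelled out...
gray = F.convert_color_space(
    rgb, color_space=features.ColorSpace.GRAY, old_color_space=features.ColorSpace.RGB
)

# ...while a Video feature knows its own color space (assumed constructor kwarg).
video = features.Video(rgb, color_space=features.ColorSpace.RGB)
gray_video = F.convert_color_space(video, color_space=features.ColorSpace.GRAY)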
@@ -9,18 +9,22 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image

normalize_image_tensor = _FT.normalize


def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
    return normalize_image_tensor(video, mean, std, inplace=inplace)


def normalize(
    inpt: features.TensorImageOrVideoTypeJIT, mean: List[float], std: List[float], inplace: bool = False
) -> torch.Tensor:
    if torch.jit.is_scripting():
        correct_type = isinstance(inpt, torch.Tensor)
    else:
        correct_type = features.is_simple_tensor(inpt) or isinstance(inpt, (features.Image, features.Video))
        inpt = inpt.as_subclass(torch.Tensor)
    if not correct_type:
        raise TypeError(f"img should be Tensor Image. Got {type(inpt)}")

    # The Image or Video type is not retained after normalization, since the data range is unknown.
    # Thus we return a plain Tensor.
    return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
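The practical consequence is that normalization strips the feature wrapper: feed in a `features.Video` and you get back a plain `torch.Tensor`. A quick demonstration under the same API assumptions as above:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

video = features.Video(torch.rand(8, 3, 16, 16))
out = F.normalize(video, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

assert type(out) is torch.Tensor  # the Video subclass is intentionally dropped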
@@ -64,6 +68,30 @@ def gaussian_blur_image_pil(
    return to_pil_image(output, mode=image.mode)


def gaussian_blur_video(
    video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor:
    # TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when
    # https://github.com/pytorch/vision/issues/6670 is resolved.
    if video.numel() == 0:
        return video

    shape = video.shape

    if video.ndim > 4:
        video = video.view((-1,) + shape[-3:])
        needs_unsquash = True
    else:
        needs_unsquash = False

    output = gaussian_blur_image_tensor(video, kernel_size, sigma)

    if needs_unsquash:
        output = output.view(shape)

    return output


def gaussian_blur(
    inpt: features.InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> features.InputTypeJIT:
...