".github/vscode:/vscode.git/clone" did not exist on "c3d78cd3067612175ac9f0f8b234abf5a2e1f510"
Commit e96860d6 authored by Vasilis Vryniotis, committed by GitHub

Adding Uniform Temporal Subsampling for Video (#6812)



* Adding temporal sampling kernel and dispatcher.

* Adding the UniformTemporalSubsample class.

* Add it to __init__

* Adding tests.

* Addressing comments.

* Reverting proposal as it led to different results.

* Add more tests for uniform_temporal_subsample

* Clean up

* Fix logic

* Fix logic

* Make test more strict

* Lint

* Update torchvision/prototype/transforms/functional/_temporal.py

* Remove pytorchvideo again per request

Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 9c112935
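
For context, the new kernel picks `num_samples` frame indices evenly spaced over the temporal axis with `torch.linspace`. A minimal sketch of that index computation (plain PyTorch, not part of this diff; the shapes are made up for illustration):

import torch

t, num_samples = 10, 5  # 10 input frames, keep 5
indices = torch.linspace(0, t - 1, num_samples).long()  # tensor([0, 2, 4, 6, 9])
video = torch.rand(t, 3, 8, 8)  # (T, C, H, W)
clip = torch.index_select(video, 0, indices)  # clip.shape == (5, 3, 8, 8)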
@@ -426,4 +426,13 @@ DISPATCHER_INFOS = [
            skip_dispatch_feature,
        ],
    ),
    DispatcherInfo(
        F.uniform_temporal_subsample,
        kernels={
            features.Video: F.uniform_temporal_subsample_video,
        },
        test_marks=[
            skip_dispatch_feature,
        ],
    ),
]
@@ -2100,3 +2100,43 @@ KERNEL_INFOS.extend(
        ),
    ]
)


def sample_inputs_uniform_temporal_subsample_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=[4]):
        for temporal_dim in [-4, len(video_loader.shape) - 4]:
            yield ArgsKwargs(video_loader, num_samples=2, temporal_dim=temporal_dim)


def reference_uniform_temporal_subsample_video(x, num_samples, temporal_dim=-4):
    # Copy-pasted from
    # https://github.com/facebookresearch/pytorchvideo/blob/c8d23d8b7e597586a9e2d18f6ed31ad8aa379a7a/pytorchvideo/transforms/functional.py#L19
    t = x.shape[temporal_dim]
    assert num_samples > 0 and t > 0
    # Sample by nearest neighbor interpolation if num_samples > t.
    indices = torch.linspace(0, t - 1, num_samples)
    indices = torch.clamp(indices, 0, t - 1).long()
    return torch.index_select(x, temporal_dim, indices)


def reference_inputs_uniform_temporal_subsample_video():
    for video_loader in make_video_loaders(sizes=["random"], color_spaces=[features.ColorSpace.RGB], num_frames=[10]):
        for num_samples in range(1, video_loader.shape[-4] + 1):
            yield ArgsKwargs(video_loader, num_samples)


KERNEL_INFOS.append(
    KernelInfo(
        F.uniform_temporal_subsample_video,
        sample_inputs_fn=sample_inputs_uniform_temporal_subsample_video,
        reference_fn=reference_uniform_temporal_subsample_video,
        reference_inputs_fn=reference_inputs_uniform_temporal_subsample_video,
        test_marks=[
            TestMark(
                ("TestKernels", "test_batched_vs_single"),
                pytest.mark.skip("Positive `temporal_dim` arguments are not equivalent for batched and single inputs"),
                # Default to -4 so inputs that omit `temporal_dim` are not skipped (and `None >= 0` cannot raise).
                condition=lambda args_kwargs: args_kwargs.kwargs.get("temporal_dim", -4) >= 0,
            ),
        ],
    )
)
@@ -1899,3 +1899,22 @@ def test_transpose_dimensions(dims):
            assert type(transformed_value) == torch.Tensor
        else:
            assert transformed_value is value


class TestUniformTemporalSubsample:
    @pytest.mark.parametrize(
        "inpt",
        [
            torch.zeros(10, 3, 8, 8),
            torch.zeros(1, 10, 3, 8, 8),
            features.Video(torch.zeros(1, 10, 3, 8, 8)),
        ],
    )
    def test__transform(self, inpt):
        num_samples = 5

        transform = transforms.UniformTemporalSubsample(num_samples)
        output = transform(inpt)

        assert type(output) is type(inpt)
        assert output.shape[-4] == num_samples
        assert output.dtype == inpt.dtype
@@ -242,6 +242,7 @@ class TestDispatchers:
            F.get_num_frames,
            F.get_spatial_size,
            F.rgb_to_grayscale,
            F.uniform_temporal_subsample,
        ],
        ids=lambda dispatcher: dispatcher.__name__,
    )
@@ -1060,3 +1061,13 @@ def test_equalize_image_tensor_edge_cases():
    inpt[..., 100:, 100:] = 1
    output = F.equalize_image_tensor(inpt)
    assert output.unique().tolist() == [0, 255]


@pytest.mark.parametrize("device", cpu_and_gpu())
def test_correctness_uniform_temporal_subsample(device):
    # Fill each frame with its temporal index, so the unique values of the output
    # reveal exactly which frames were selected.
    video = torch.arange(10, device=device)[:, None, None, None].expand(-1, 3, 8, 8)

    # linspace(0, 9, 5) = [0, 2.25, 4.5, 6.75, 9] -> long() -> [0, 2, 4, 6, 9]
    out_video = F.uniform_temporal_subsample(video, 5)
    assert out_video.unique().tolist() == [0, 2, 4, 6, 9]

    # linspace(0, 9, 8) truncates to [0, 1, 2, 3, 5, 6, 7, 9]
    out_video = F.uniform_temporal_subsample(video, 8)
    assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]
@@ -51,6 +51,7 @@ from ._misc import (
    ToDtype,
    TransposeDimensions,
)
from ._temporal import UniformTemporalSubsample
from ._type_conversion import DecodeImage, LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage

from ._deprecated import Grayscale, RandomGrayscale, ToTensor  # usort: skip
......
from typing import Any, Dict

from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform


class UniformTemporalSubsample(Transform):
    _transformed_types = (features.is_simple_tensor, features.Video)

    def __init__(self, num_samples: int, temporal_dim: int = -4):
        super().__init__()
        self.num_samples = num_samples
        self.temporal_dim = temporal_dim

    def _transform(self, inpt: features.VideoType, params: Dict[str, Any]) -> features.VideoType:
        return F.uniform_temporal_subsample(inpt, self.num_samples, temporal_dim=self.temporal_dim)
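
A minimal usage sketch of the transform added above, assuming the prototype namespace wired up in this PR (the input shape is an illustrative assumption):

import torch
from torchvision.prototype import features, transforms

video = features.Video(torch.rand(1, 10, 3, 32, 32))  # (B, T, C, H, W)
clip = transforms.UniformTemporalSubsample(num_samples=5)(video)
assert isinstance(clip, features.Video) and clip.shape == (1, 5, 3, 32, 32)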
@@ -165,6 +165,7 @@ from ._misc import (
    normalize_image_tensor,
    normalize_video,
)
from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video
from ._type_conversion import (
    decode_image_with_pil,
    decode_video_with_av,
......
import torch

from torchvision.prototype import features


def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temporal_dim: int = -4) -> torch.Tensor:
    # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
    t_max = video.shape[temporal_dim] - 1
    # Evenly spaced indices over [0, t_max]; indices repeat (nearest neighbor) when num_samples > T.
    indices = torch.linspace(0, t_max, num_samples, device=video.device).long()
    return torch.index_select(video, temporal_dim, indices)


def uniform_temporal_subsample(
    inpt: features.VideoTypeJIT, num_samples: int, temporal_dim: int = -4
) -> features.VideoTypeJIT:
    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Video)):
        # Simple tensors, and everything under torch.jit.script, go straight to the kernel.
        return uniform_temporal_subsample_video(inpt, num_samples, temporal_dim=temporal_dim)
    else:  # isinstance(inpt, features.Video)
        if temporal_dim != -4 and inpt.ndim - 4 != temporal_dim:
            raise ValueError("Video inputs must have temporal_dim equivalent to -4")

        output = uniform_temporal_subsample_video(
            inpt.as_subclass(torch.Tensor), num_samples, temporal_dim=temporal_dim
        )
        return features.Video.wrap_like(inpt, output)
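
The dispatcher above preserves the input type. A small sketch of the expected eager behavior (shapes are assumptions; the calls mirror the tests in this PR):

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

plain = torch.rand(10, 3, 8, 8)  # simple (T, C, H, W) tensor -> kernel is called directly
video = features.Video(torch.rand(1, 10, 3, 8, 8))  # feature -> unwrapped, then re-wrapped via wrap_like

assert type(F.uniform_temporal_subsample(plain, 4)) is torch.Tensor
assert type(F.uniform_temporal_subsample(video, 4)) is features.Video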