Unverified Commit 3991ab99 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Promote prototype transforms to beta status (#7261)


Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
Co-authored-by: default avatarvfdev-5 <vfdev.5@gmail.com>
parent d010e82f
import torch
from torchvision.prototype import datapoints
from torchvision import datapoints
from torchvision.utils import _log_api_usage_once
from ._utils import is_simple_tensor
def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temporal_dim: int = -4) -> torch.Tensor:
def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor:
# Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
t_max = video.shape[temporal_dim] - 1
t_max = video.shape[-4] - 1
indices = torch.linspace(0, t_max, num_samples, device=video.device).long()
return torch.index_select(video, temporal_dim, indices)
return torch.index_select(video, -4, indices)
def uniform_temporal_subsample(
inpt: datapoints.VideoTypeJIT, num_samples: int, temporal_dim: int = -4
) -> datapoints.VideoTypeJIT:
def uniform_temporal_subsample(inpt: datapoints.VideoTypeJIT, num_samples: int) -> datapoints.VideoTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(uniform_temporal_subsample)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return uniform_temporal_subsample_video(inpt, num_samples, temporal_dim=temporal_dim)
return uniform_temporal_subsample_video(inpt, num_samples)
elif isinstance(inpt, datapoints.Video):
if temporal_dim != -4 and inpt.ndim - 4 != temporal_dim:
raise ValueError("Video inputs must have temporal_dim equivalent to -4")
output = uniform_temporal_subsample_video(
inpt.as_subclass(torch.Tensor), num_samples, temporal_dim=temporal_dim
)
output = uniform_temporal_subsample_video(inpt.as_subclass(torch.Tensor), num_samples)
return datapoints.Video.wrap_like(inpt, output)
else:
raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.")
......@@ -3,7 +3,7 @@ from typing import Union
import numpy as np
import PIL.Image
import torch
from torchvision.prototype import datapoints
from torchvision import datapoints
from torchvision.transforms import functional as _F
......
from typing import Any
import torch
from torchvision.prototype.datapoints._datapoint import Datapoint
from torchvision.datapoints._datapoint import Datapoint
def is_simple_tensor(inpt: Any) -> bool:
......
......@@ -3,10 +3,10 @@ from __future__ import annotations
from typing import Any, Callable, List, Tuple, Type, Union
import PIL.Image
from torchvision import datapoints
from torchvision._utils import sequence_to_str
from torchvision.prototype import datapoints
from torchvision.prototype.transforms.functional import get_dimensions, get_spatial_size, is_simple_tensor
from torchvision.transforms.v2.functional import get_dimensions, get_spatial_size, is_simple_tensor
def query_bounding_box(flat_inputs: List[Any]) -> datapoints.BoundingBox:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment