Unverified commit 3991ab99, authored by Philip Meier and committed by GitHub

Promote prototype transforms to beta status (#7261)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: vfdev-5 <vfdev.5@gmail.com>
parent d010e82f
 import torch
-from torchvision.prototype import datapoints
+from torchvision import datapoints
 from torchvision.utils import _log_api_usage_once

 from ._utils import is_simple_tensor


-def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temporal_dim: int = -4) -> torch.Tensor:
+def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor:
     # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
-    t_max = video.shape[temporal_dim] - 1
+    t_max = video.shape[-4] - 1
     indices = torch.linspace(0, t_max, num_samples, device=video.device).long()
-    return torch.index_select(video, temporal_dim, indices)
+    return torch.index_select(video, -4, indices)


-def uniform_temporal_subsample(
-    inpt: datapoints.VideoTypeJIT, num_samples: int, temporal_dim: int = -4
-) -> datapoints.VideoTypeJIT:
+def uniform_temporal_subsample(inpt: datapoints.VideoTypeJIT, num_samples: int) -> datapoints.VideoTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(uniform_temporal_subsample)

     if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return uniform_temporal_subsample_video(inpt, num_samples, temporal_dim=temporal_dim)
+        return uniform_temporal_subsample_video(inpt, num_samples)
     elif isinstance(inpt, datapoints.Video):
-        if temporal_dim != -4 and inpt.ndim - 4 != temporal_dim:
-            raise ValueError("Video inputs must have temporal_dim equivalent to -4")
-        output = uniform_temporal_subsample_video(
-            inpt.as_subclass(torch.Tensor), num_samples, temporal_dim=temporal_dim
-        )
+        output = uniform_temporal_subsample_video(inpt.as_subclass(torch.Tensor), num_samples)
         return datapoints.Video.wrap_like(inpt, output)
     else:
         raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.")
@@ -3,7 +3,7 @@ from typing import Union
 import numpy as np
 import PIL.Image
 import torch
-from torchvision.prototype import datapoints
+from torchvision import datapoints
 from torchvision.transforms import functional as _F
...
 from typing import Any

 import torch
-from torchvision.prototype.datapoints._datapoint import Datapoint
+from torchvision.datapoints._datapoint import Datapoint


 def is_simple_tensor(inpt: Any) -> bool:
...
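The body of `is_simple_tensor` is collapsed above. Conceptually it separates plain tensors from `Datapoint` subclasses such as `Video`; a minimal sketch of that check (an assumption based on the imports shown, not the verbatim source):

```python
from typing import Any

import torch
from torchvision.datapoints._datapoint import Datapoint


def is_simple_tensor(inpt: Any) -> bool:
    # A "simple" tensor is any torch.Tensor that is not a Datapoint subclass.
    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, Datapoint)
```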
@@ -3,10 +3,10 @@ from __future__ import annotations
 from typing import Any, Callable, List, Tuple, Type, Union

 import PIL.Image

+from torchvision import datapoints
 from torchvision._utils import sequence_to_str
-from torchvision.prototype import datapoints
-from torchvision.prototype.transforms.functional import get_dimensions, get_spatial_size, is_simple_tensor
+from torchvision.transforms.v2.functional import get_dimensions, get_spatial_size, is_simple_tensor


 def query_bounding_box(flat_inputs: List[Any]) -> datapoints.BoundingBox:
...
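For context, `query_bounding_box` retrieves the `BoundingBox` from a sample's flattened inputs. A hedged usage sketch follows; the `torchvision.transforms.v2.utils` import path and the sample layout are assumptions, not shown in this diff:

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2.utils import query_bounding_box  # assumed path

image = datapoints.Image(torch.rand(3, 224, 224))
boxes = datapoints.BoundingBox(
    [[10.0, 10.0, 50.0, 50.0]],
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(224, 224),
)

# Transforms flatten their (possibly nested) inputs and query for the
# single bounding box datapoint among them.
print(query_bounding_box([image, boxes]))
```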