Unverified commit 999ef255, authored by F-G Fernandez and committed by GitHub

Added missing typing annotations in datasets/video_utils (#4172)



* style: Fixed last missing typing annotation

* style: Fixed typing

* style: Fixed remaining typing annotations

* style: Fixed typing

* style: Fixed typing

* refactor: Removed unused import

* Update torchvision/datasets/video_utils.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Prabhat Roy <prabhatroy@fb.com>
parent ef2e4187
@@ -2,7 +2,7 @@ import bisect
import math
import warnings
from fractions import Fraction
from typing import List
from typing import Any, Dict, List, Optional, Callable, Union, Tuple, TypeVar, cast
import torch
from torchvision.io import (
@@ -14,8 +14,10 @@ from torchvision.io import (
from .utils import tqdm
T = TypeVar("T")
def pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
def pts_convert(pts: int, timebase_from: Fraction, timebase_to: Fraction, round_func: Callable = math.floor) -> int:
"""convert pts between different time bases
Args:
pts: presentation timestamp, float
@@ -27,7 +29,7 @@ def pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
return round_func(new_pts)
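As a quick sanity check of the newly annotated signature, pts_convert can be exercised directly; the pts value and time bases below are made up for illustration:

    import math
    from fractions import Fraction

    from torchvision.datasets.video_utils import pts_convert

    # 900 ticks at 1/90000 s/tick is 10 ms, which is 441 ticks at 1/44100 s/tick,
    # so floor() leaves the result unchanged.
    new_pts: int = pts_convert(900, Fraction(1, 90000), Fraction(1, 44100), math.floor)
    assert new_pts == 441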
def unfold(tensor, size, step, dilation=1):
def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> torch.Tensor:
"""
similar to tensor.unfold, but with the dilation
and specialized for 1d tensors
@@ -55,17 +57,17 @@ class _VideoTimestampsDataset:
pickled when forking.
"""
def __init__(self, video_paths: List[str]):
def __init__(self, video_paths: List[str]) -> None:
self.video_paths = video_paths
def __len__(self):
def __len__(self) -> int:
return len(self.video_paths)
def __getitem__(self, idx):
def __getitem__(self, idx: int) -> Tuple[List[int], Optional[float]]:
return read_video_timestamps(self.video_paths[idx])
def _collate_fn(x):
def _collate_fn(x: T) -> T:
"""
Dummy collate function to be used with _VideoTimestampsDataset
"""
@@ -100,19 +102,19 @@ class VideoClips:
def __init__(
self,
video_paths,
clip_length_in_frames=16,
frames_between_clips=1,
frame_rate=None,
_precomputed_metadata=None,
num_workers=0,
_video_width=0,
_video_height=0,
_video_min_dimension=0,
_video_max_dimension=0,
_audio_samples=0,
_audio_channels=0,
):
video_paths: List[str],
clip_length_in_frames: int = 16,
frames_between_clips: int = 1,
frame_rate: Optional[int] = None,
_precomputed_metadata: Optional[Dict[str, Any]] = None,
num_workers: int = 0,
_video_width: int = 0,
_video_height: int = 0,
_video_min_dimension: int = 0,
_video_max_dimension: int = 0,
_audio_samples: int = 0,
_audio_channels: int = 0,
) -> None:
self.video_paths = video_paths
self.num_workers = num_workers
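With the constructor fully annotated, call sites can now be type-checked as well. A small usage sketch, assuming a couple of local video files (the paths are hypothetical):

    from typing import List

    from torchvision.datasets.video_utils import VideoClips

    video_paths: List[str] = ["clip_a.mp4", "clip_b.mp4"]  # hypothetical files

    clips = VideoClips(
        video_paths=video_paths,
        clip_length_in_frames=16,   # frames per clip
        frames_between_clips=8,     # stride between clip starts
        frame_rate=None,            # Optional[int]: keep each video's native frame rate
        num_workers=2,              # workers used to read video timestamps
    )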
@@ -131,7 +133,7 @@ class VideoClips:
self._init_from_metadata(_precomputed_metadata)
self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate)
def _compute_frame_pts(self):
def _compute_frame_pts(self) -> None:
self.video_pts = []
self.video_fps = []
@@ -139,8 +141,8 @@ class VideoClips:
# so need to create a dummy dataset first
import torch.utils.data
dl = torch.utils.data.DataLoader(
_VideoTimestampsDataset(self.video_paths),
dl: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
_VideoTimestampsDataset(self.video_paths), # type: ignore[arg-type]
batch_size=16,
num_workers=self.num_workers,
collate_fn=_collate_fn,
@@ -157,7 +159,7 @@ class VideoClips:
self.video_pts.extend(clips)
self.video_fps.extend(fps)
def _init_from_metadata(self, metadata):
def _init_from_metadata(self, metadata: Dict[str, Any]) -> None:
self.video_paths = metadata["video_paths"]
assert len(self.video_paths) == len(metadata["video_pts"])
self.video_pts = metadata["video_pts"]
@@ -165,7 +167,7 @@ class VideoClips:
self.video_fps = metadata["video_fps"]
@property
def metadata(self):
def metadata(self) -> Dict[str, Any]:
_metadata = {
"video_paths": self.video_paths,
"video_pts": self.video_pts,
@@ -173,7 +175,7 @@ class VideoClips:
}
return _metadata
def subset(self, indices):
def subset(self, indices: List[int]) -> "VideoClips":
video_paths = [self.video_paths[i] for i in indices]
video_pts = [self.video_pts[i] for i in indices]
video_fps = [self.video_fps[i] for i in indices]
@@ -198,7 +200,9 @@ class VideoClips:
)
@staticmethod
def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
def compute_clips_for_video(
video_pts: torch.Tensor, num_frames: int, step: int, fps: int, frame_rate: Optional[int] = None
) -> Tuple[torch.Tensor, Union[List[slice], torch.Tensor]]:
if fps is None:
# if for some reason the video doesn't have fps (because doesn't have a video stream)
# set the fps to 1. The value doesn't matter, because video_pts is empty anyway
@@ -206,21 +210,22 @@ class VideoClips:
if frame_rate is None:
frame_rate = fps
total_frames = len(video_pts) * (float(frame_rate) / fps)
idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
video_pts = video_pts[idxs]
_idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
video_pts = video_pts[_idxs]
clips = unfold(video_pts, num_frames, step)
if not clips.numel():
warnings.warn(
"There aren't enough frames in the current video to get a clip for the given clip length and "
"frames between clips. The video (and potentially others) will be skipped."
)
if isinstance(idxs, slice):
idxs = [idxs] * len(clips)
idxs: Union[List[slice], torch.Tensor]
if isinstance(_idxs, slice):
idxs = [_idxs] * len(clips)
else:
idxs = unfold(idxs, num_frames, step)
idxs = unfold(_idxs, num_frames, step)
return clips, idxs
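The bare idxs: Union[List[slice], torch.Tensor] declaration above, together with renaming the intermediate result to _idxs, is the usual way to let mypy accept a variable that gets a different type on each branch. A standalone sketch of the pattern with illustrative names:

    from typing import List, Union

    import torch

    def select_indices(use_slices: bool) -> Union[List[slice], torch.Tensor]:
        # Declare the variable with its full union type up front so that both
        # branch assignments are checked against the same declared type.
        out: Union[List[slice], torch.Tensor]
        if use_slices:
            out = [slice(0, 4)] * 3
        else:
            out = torch.arange(12)
        return out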
def compute_clips(self, num_frames, step, frame_rate=None):
def compute_clips(self, num_frames: int, step: int, frame_rate: Optional[int] = None) -> None:
"""
Compute all consecutive sequences of clips from video_pts.
Always returns clips of size `num_frames`, meaning that the
@@ -243,19 +248,19 @@ class VideoClips:
clip_lengths = torch.as_tensor([len(v) for v in self.clips])
self.cumulative_sizes = clip_lengths.cumsum(0).tolist()
def __len__(self):
def __len__(self) -> int:
return self.num_clips()
def num_videos(self):
def num_videos(self) -> int:
return len(self.video_paths)
def num_clips(self):
def num_clips(self) -> int:
"""
Number of subclips that are available in the video list.
"""
return self.cumulative_sizes[-1]
def get_clip_location(self, idx):
def get_clip_location(self, idx: int) -> Tuple[int, int]:
"""
Converts a flattened representation of the indices into a video_idx, clip_idx
representation.
@@ -268,7 +273,7 @@ class VideoClips:
return video_idx, clip_idx
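The Tuple[int, int] return type on get_clip_location documents the mapping from a flat clip index to a (video_idx, clip_idx) pair. A tiny standalone illustration of this kind of bisect-based lookup, with made-up clip counts:

    import bisect

    cumulative_sizes = [3, 5, 9]   # videos contribute 3, 2 and 4 clips respectively
    idx = 4                        # flat clip index across all videos

    video_idx = bisect.bisect_right(cumulative_sizes, idx)
    clip_idx = idx if video_idx == 0 else idx - cumulative_sizes[video_idx - 1]
    assert (video_idx, clip_idx) == (1, 1)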
@staticmethod
def _resample_video_idx(num_frames, original_fps, new_fps):
def _resample_video_idx(num_frames: int, original_fps: int, new_fps: int) -> Union[slice, torch.Tensor]:
step = float(original_fps) / new_fps
if step.is_integer():
# optimization: if step is integer, don't need to perform
@@ -279,7 +284,7 @@ class VideoClips:
idxs = idxs.floor().to(torch.int64)
return idxs
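The Union[slice, torch.Tensor] return type works downstream because a pts tensor can be subscripted with either form; a quick standalone check:

    import torch

    video_pts = torch.arange(10)

    # Every second frame, once via a slice and once via an int64 index tensor.
    assert torch.equal(video_pts[slice(None, None, 2)], video_pts[torch.arange(0, 10, 2)])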
def get_clip(self, idx):
def get_clip(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any], int]:
"""
Gets a subclip from a list of videos.
@@ -320,22 +325,22 @@ class VideoClips:
end_pts = clip_pts[-1].item()
video, audio, info = read_video(video_path, start_pts, end_pts)
else:
info = _probe_video_from_file(video_path)
video_fps = info.video_fps
_info = _probe_video_from_file(video_path)
video_fps = _info.video_fps
audio_fps = None
video_start_pts = clip_pts[0].item()
video_end_pts = clip_pts[-1].item()
video_start_pts = cast(int, clip_pts[0].item())
video_end_pts = cast(int, clip_pts[-1].item())
audio_start_pts, audio_end_pts = 0, -1
audio_timebase = Fraction(0, 1)
video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
if info.has_audio:
audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
video_timebase = Fraction(_info.video_timebase.numerator, _info.video_timebase.denominator)
if _info.has_audio:
audio_timebase = Fraction(_info.audio_timebase.numerator, _info.audio_timebase.denominator)
audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor)
audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil)
audio_fps = info.audio_sample_rate
video, audio, info = _read_video_from_file(
audio_fps = _info.audio_sample_rate
video, audio, _ = _read_video_from_file(
video_path,
video_width=self._video_width,
video_height=self._video_height,
@@ -362,7 +367,7 @@ class VideoClips:
assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}"
return video, audio, info, video_idx
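The cast(int, ...) calls in get_clip above are needed because Tensor.item() is annotated as returning a generic Python number, while the pts values are used as ints; cast() only narrows the static type and has no runtime effect. A standalone sketch:

    from typing import cast

    import torch

    clip_pts = torch.tensor([0, 512, 1024])

    # item() returns a plain Python number; cast() tells the checker these are ints.
    start_pts = cast(int, clip_pts[0].item())
    end_pts = cast(int, clip_pts[-1].item())
    assert isinstance(start_pts, int) and isinstance(end_pts, int)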
def __getstate__(self):
def __getstate__(self) -> Dict[str, Any]:
video_pts_sizes = [len(v) for v in self.video_pts]
# To be back-compatible, we convert data to dtype torch.long as needed
# because for empty list, in legacy implementation, torch.as_tensor will
@@ -371,10 +376,10 @@ class VideoClips:
video_pts = [x.to(torch.int64) for x in self.video_pts]
# video_pts can be an empty list if no frames have been decoded
if video_pts:
video_pts = torch.cat(video_pts)
video_pts = torch.cat(video_pts) # type: ignore[assignment]
# avoid bug in https://github.com/pytorch/pytorch/issues/32351
# TODO: Revert it once the bug is fixed.
video_pts = video_pts.numpy()
video_pts = video_pts.numpy() # type: ignore[attr-defined]
# make a copy of the fields of self
d = self.__dict__.copy()
@@ -390,7 +395,7 @@ class VideoClips:
d["_version"] = 2
return d
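The two # type: ignore comments in __getstate__ above exist because video_pts is first inferred as a list of tensors and is then rebound to a single tensor and finally to a NumPy array, so mypy flags both the re-assignment and the .numpy() attribute access. One way to avoid the ignores, not used in the diff, is to bind each converted value to a fresh name; a rough sketch under that assumption:

    from typing import List

    import numpy as np
    import torch

    video_pts_list: List[torch.Tensor] = [torch.tensor([0, 1]), torch.tensor([2, 3])]

    # Keeping each variable at a single type means no "# type: ignore" is needed.
    video_pts_cat: torch.Tensor = (
        torch.cat(video_pts_list) if video_pts_list else torch.zeros(0, dtype=torch.int64)
    )
    video_pts_np: np.ndarray = video_pts_cat.numpy()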
def __setstate__(self, d):
def __setstate__(self, d: Dict[str, Any]) -> None:
# for backwards-compatibility
if "_version" not in d:
self.__dict__ = d
......