Unverified Commit 999ef255 authored by F-G Fernandez, committed by GitHub

Added missing typing annotations in datasets/video_utils (#4172)



* style: Fixed last missing typing annotation

* style: Fixed typing

* style: Fixed remaining typing annotations

* style: Fixed typing

* style: Fixed typing

* refactor: Removed unused import

* Update torchvision/datasets/video_utils.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Prabhat Roy <prabhatroy@fb.com>
parent ef2e4187
torchvision/datasets/video_utils.py
@@ -2,7 +2,7 @@ import bisect
 import math
 import warnings
 from fractions import Fraction
-from typing import List
+from typing import Any, Dict, List, Optional, Callable, Union, Tuple, TypeVar, cast

 import torch
 from torchvision.io import (
@@ -14,8 +14,10 @@ from torchvision.io import (
 from .utils import tqdm

+T = TypeVar("T")
+

-def pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
+def pts_convert(pts: int, timebase_from: Fraction, timebase_to: Fraction, round_func: Callable = math.floor) -> int:
     """convert pts between different time bases
     Args:
         pts: presentation timestamp, float
@@ -27,7 +29,7 @@ def pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
     return round_func(new_pts)


-def unfold(tensor, size, step, dilation=1):
+def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> torch.Tensor:
     """
     similar to tensor.unfold, but with the dilation
     and specialized for 1d tensors
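A minimal usage sketch of the newly annotated `pts_convert`, assuming it rescales `pts` from `timebase_from` into `timebase_to` units, which is how `get_clip` uses it further down; the timebase values here are invented for illustration:

```python
import math
from fractions import Fraction

from torchvision.datasets.video_utils import pts_convert

video_tb = Fraction(1, 90000)  # hypothetical 90 kHz video clock
audio_tb = Fraction(1, 44100)  # hypothetical 44.1 kHz audio clock

# 900000 ticks of 1/90000 s is 10 s, which is 441000 ticks of 1/44100 s.
audio_pts = pts_convert(900000, video_tb, audio_tb, math.floor)
assert audio_pts == 441000
```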
@@ -55,17 +57,17 @@ class _VideoTimestampsDataset:
     pickled when forking.
     """

-    def __init__(self, video_paths: List[str]):
+    def __init__(self, video_paths: List[str]) -> None:
         self.video_paths = video_paths

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.video_paths)

-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int) -> Tuple[List[int], Optional[float]]:
         return read_video_timestamps(self.video_paths[idx])


-def _collate_fn(x):
+def _collate_fn(x: T) -> T:
     """
     Dummy collate function to be used with _VideoTimestampsDataset
     """
@@ -100,19 +102,19 @@ class VideoClips:

     def __init__(
         self,
-        video_paths,
-        clip_length_in_frames=16,
-        frames_between_clips=1,
-        frame_rate=None,
-        _precomputed_metadata=None,
-        num_workers=0,
-        _video_width=0,
-        _video_height=0,
-        _video_min_dimension=0,
-        _video_max_dimension=0,
-        _audio_samples=0,
-        _audio_channels=0,
-    ):
+        video_paths: List[str],
+        clip_length_in_frames: int = 16,
+        frames_between_clips: int = 1,
+        frame_rate: Optional[int] = None,
+        _precomputed_metadata: Optional[Dict[str, Any]] = None,
+        num_workers: int = 0,
+        _video_width: int = 0,
+        _video_height: int = 0,
+        _video_min_dimension: int = 0,
+        _video_max_dimension: int = 0,
+        _audio_samples: int = 0,
+        _audio_channels: int = 0,
+    ) -> None:
         self.video_paths = video_paths
         self.num_workers = num_workers
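For reference, a typical construction that matches the annotated signature above; the file names and parameter values are illustrative only:

```python
from torchvision.datasets.video_utils import VideoClips

video_paths = ["clips/a.mp4", "clips/b.mp4"]  # hypothetical files on disk
clips = VideoClips(
    video_paths,
    clip_length_in_frames=16,  # frames per clip
    frames_between_clips=1,    # stride between clip starts
    frame_rate=15,             # resample every video to 15 fps before clipping
    num_workers=4,             # workers for the timestamp-scanning DataLoader
)
print(clips.num_videos(), clips.num_clips())
```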
@@ -131,7 +133,7 @@ class VideoClips:
             self._init_from_metadata(_precomputed_metadata)
         self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate)

-    def _compute_frame_pts(self):
+    def _compute_frame_pts(self) -> None:
         self.video_pts = []
         self.video_fps = []
@@ -139,8 +141,8 @@ class VideoClips:
         # so need to create a dummy dataset first
         import torch.utils.data

-        dl = torch.utils.data.DataLoader(
-            _VideoTimestampsDataset(self.video_paths),
+        dl: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
+            _VideoTimestampsDataset(self.video_paths),  # type: ignore[arg-type]
             batch_size=16,
             num_workers=self.num_workers,
             collate_fn=_collate_fn,
@@ -157,7 +159,7 @@ class VideoClips:
             self.video_pts.extend(clips)
             self.video_fps.extend(fps)

-    def _init_from_metadata(self, metadata):
+    def _init_from_metadata(self, metadata: Dict[str, Any]) -> None:
         self.video_paths = metadata["video_paths"]
         assert len(self.video_paths) == len(metadata["video_pts"])
         self.video_pts = metadata["video_pts"]
@@ -165,7 +167,7 @@ class VideoClips:
         self.video_fps = metadata["video_fps"]

     @property
-    def metadata(self):
+    def metadata(self) -> Dict[str, Any]:
         _metadata = {
             "video_paths": self.video_paths,
             "video_pts": self.video_pts,
@@ -173,7 +175,7 @@
         }
         return _metadata

-    def subset(self, indices):
+    def subset(self, indices: List[int]) -> "VideoClips":
         video_paths = [self.video_paths[i] for i in indices]
         video_pts = [self.video_pts[i] for i in indices]
         video_fps = [self.video_fps[i] for i in indices]
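The `metadata` property, the `_precomputed_metadata` argument, and `subset` work together as a cache: scan the videos once, reuse the `Dict[str, Any]` later, and carve out sub-lists without re-decoding timestamps. A sketch continuing the illustrative `clips` instance from the constructor example above:

```python
cached = clips.metadata
# cached is a Dict[str, Any] with "video_paths", "video_pts" and "video_fps" entries.

# Re-clipping with a different clip length reuses the cached timestamps instead of rescanning.
longer_clips = VideoClips(video_paths, clip_length_in_frames=32, _precomputed_metadata=cached)

# subset(List[int]) -> "VideoClips": a new instance restricted to the selected videos.
first_only = longer_clips.subset([0])
assert first_only.num_videos() == 1
```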
@@ -198,7 +200,9 @@ class VideoClips:
         )

     @staticmethod
-    def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
+    def compute_clips_for_video(
+        video_pts: torch.Tensor, num_frames: int, step: int, fps: int, frame_rate: Optional[int] = None
+    ) -> Tuple[torch.Tensor, Union[List[slice], torch.Tensor]]:
         if fps is None:
             # if for some reason the video doesn't have fps (because doesn't have a video stream)
             # set the fps to 1. The value doesn't matter, because video_pts is empty anyway
@@ -206,21 +210,22 @@ class VideoClips:
         if frame_rate is None:
             frame_rate = fps
         total_frames = len(video_pts) * (float(frame_rate) / fps)
-        idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
-        video_pts = video_pts[idxs]
+        _idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
+        video_pts = video_pts[_idxs]
         clips = unfold(video_pts, num_frames, step)
         if not clips.numel():
             warnings.warn(
                 "There aren't enough frames in the current video to get a clip for the given clip length and "
                 "frames between clips. The video (and potentially others) will be skipped."
             )
-        if isinstance(idxs, slice):
-            idxs = [idxs] * len(clips)
+        idxs: Union[List[slice], torch.Tensor]
+        if isinstance(_idxs, slice):
+            idxs = [_idxs] * len(clips)
         else:
-            idxs = unfold(idxs, num_frames, step)
+            idxs = unfold(_idxs, num_frames, step)
         return clips, idxs

-    def compute_clips(self, num_frames, step, frame_rate=None):
+    def compute_clips(self, num_frames: int, step: int, frame_rate: Optional[int] = None) -> None:
         """
         Compute all consecutive sequences of clips from video_pts.
         Always returns clips of size `num_frames`, meaning that the
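A worked sketch of what `compute_clips_for_video` returns under the new annotations, assuming `unfold` behaves like `torch.Tensor.unfold` when `dilation=1`; the pts values are invented:

```python
import torch

from torchvision.datasets.video_utils import VideoClips

video_pts = torch.arange(10)  # pts of 10 decoded frames

# frame_rate is left at its None default, so it falls back to fps; _resample_video_idx
# then returns a plain slice and idxs becomes a list of identical slices.
clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames=4, step=2, fps=30)

# Windows of 4 frames starting every 2 frames: starts at 0, 2, 4, 6.
# clips == [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]
assert clips.shape == (4, 4)
assert isinstance(idxs, list) and all(isinstance(i, slice) for i in idxs)
```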
@@ -243,19 +248,19 @@ class VideoClips:
         clip_lengths = torch.as_tensor([len(v) for v in self.clips])
         self.cumulative_sizes = clip_lengths.cumsum(0).tolist()

-    def __len__(self):
+    def __len__(self) -> int:
         return self.num_clips()

-    def num_videos(self):
+    def num_videos(self) -> int:
         return len(self.video_paths)

-    def num_clips(self):
+    def num_clips(self) -> int:
         """
         Number of subclips that are available in the video list.
         """
         return self.cumulative_sizes[-1]

-    def get_clip_location(self, idx):
+    def get_clip_location(self, idx: int) -> Tuple[int, int]:
         """
         Converts a flattened representation of the indices into a video_idx, clip_idx
         representation.
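`get_clip_location(idx)` now advertises its `Tuple[int, int]` result: a flat clip index is turned into `(video_idx, clip_idx)` using `cumulative_sizes`. A sketch of that mapping, assuming the bisect-based lookup that the module-level `import bisect` suggests; the clip counts are invented:

```python
import bisect
from typing import Tuple

cumulative_sizes = [3, 8, 10]  # e.g. videos with 3, 5 and 2 clips respectively


def locate(idx: int) -> Tuple[int, int]:
    video_idx = bisect.bisect_right(cumulative_sizes, idx)
    clip_idx = idx if video_idx == 0 else idx - cumulative_sizes[video_idx - 1]
    return video_idx, clip_idx


assert locate(0) == (0, 0)  # first clip of the first video
assert locate(3) == (1, 0)  # first clip of the second video
assert locate(9) == (2, 1)  # last clip overall
```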
@@ -268,7 +273,7 @@ class VideoClips:
         return video_idx, clip_idx

     @staticmethod
-    def _resample_video_idx(num_frames, original_fps, new_fps):
+    def _resample_video_idx(num_frames: int, original_fps: int, new_fps: int) -> Union[slice, torch.Tensor]:
         step = float(original_fps) / new_fps
         if step.is_integer():
             # optimization: if step is integer, don't need to perform
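The `Union[slice, torch.Tensor]` return type mirrors the two branches of `_resample_video_idx`: an integer fps ratio is served by a cheap stride slice, a fractional one by a tensor of floored indices. An illustration, assuming that slice/arange behaviour (the fps values are arbitrary):

```python
from torchvision.datasets.video_utils import VideoClips

# 30 fps -> 15 fps: step = 2.0 is an integer, so a stride-2 slice suffices.
idxs = VideoClips._resample_video_idx(num_frames=5, original_fps=30, new_fps=15)
assert idxs == slice(None, None, 2)

# 30 fps -> 20 fps: step = 1.5, so fractional positions are floored into an index tensor.
idxs = VideoClips._resample_video_idx(num_frames=5, original_fps=30, new_fps=20)
assert idxs.tolist() == [0, 1, 3, 4, 6]
```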
@@ -279,7 +284,7 @@ class VideoClips:
         idxs = idxs.floor().to(torch.int64)
         return idxs

-    def get_clip(self, idx):
+    def get_clip(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any], int]:
         """
         Gets a subclip from a list of videos.
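With the annotated return type, callers of `get_clip` unpack a `Tuple[torch.Tensor, torch.Tensor, Dict[str, Any], int]`. Continuing the illustrative `clips` instance, roughly:

```python
video, audio, info, video_idx = clips.get_clip(0)

# video:     frames of the requested clip, e.g. (16, H, W, 3) uint8 with the pyav backend
# audio:     audio samples covering the same time span (may be empty if the video has no audio)
# info:      Dict[str, Any] with stream metadata such as "video_fps"
# video_idx: index into clips.video_paths of the video the clip came from
```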
@@ -320,22 +325,22 @@ class VideoClips:
             end_pts = clip_pts[-1].item()
             video, audio, info = read_video(video_path, start_pts, end_pts)
         else:
-            info = _probe_video_from_file(video_path)
-            video_fps = info.video_fps
+            _info = _probe_video_from_file(video_path)
+            video_fps = _info.video_fps
             audio_fps = None

-            video_start_pts = clip_pts[0].item()
-            video_end_pts = clip_pts[-1].item()
+            video_start_pts = cast(int, clip_pts[0].item())
+            video_end_pts = cast(int, clip_pts[-1].item())

             audio_start_pts, audio_end_pts = 0, -1
             audio_timebase = Fraction(0, 1)
-            video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
-            if info.has_audio:
-                audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
+            video_timebase = Fraction(_info.video_timebase.numerator, _info.video_timebase.denominator)
+            if _info.has_audio:
+                audio_timebase = Fraction(_info.audio_timebase.numerator, _info.audio_timebase.denominator)
                 audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor)
                 audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil)
-                audio_fps = info.audio_sample_rate
-            video, audio, info = _read_video_from_file(
+                audio_fps = _info.audio_sample_rate
+            video, audio, _ = _read_video_from_file(
                 video_path,
                 video_width=self._video_width,
                 video_height=self._video_height,
@@ -362,7 +367,7 @@ class VideoClips:
             assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}"
         return video, audio, info, video_idx

-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         video_pts_sizes = [len(v) for v in self.video_pts]
         # To be back-compatible, we convert data to dtype torch.long as needed
         # because for empty list, in legacy implementation, torch.as_tensor will
@@ -371,10 +376,10 @@ class VideoClips:
         video_pts = [x.to(torch.int64) for x in self.video_pts]
         # video_pts can be an empty list if no frames have been decoded
         if video_pts:
-            video_pts = torch.cat(video_pts)
+            video_pts = torch.cat(video_pts)  # type: ignore[assignment]
             # avoid bug in https://github.com/pytorch/pytorch/issues/32351
             # TODO: Revert it once the bug is fixed.
-            video_pts = video_pts.numpy()
+            video_pts = video_pts.numpy()  # type: ignore[attr-defined]

         # make a copy of the fields of self
         d = self.__dict__.copy()
@@ -390,7 +395,7 @@ class VideoClips:
         d["_version"] = 2
         return d

-    def __setstate__(self, d):
+    def __setstate__(self, d: Dict[str, Any]) -> None:
         # for backwards-compatibility
         if "_version" not in d:
             self.__dict__ = d
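`__getstate__`/`__setstate__` are now typed as `Dict[str, Any]` state dicts; together they implement the version-2 pickling scheme seen above (pts packed into a flat buffer, `_version` tag checked on load). A round-trip sketch, assuming `__setstate__` restores the per-video pts and clip index:

```python
import pickle

state = pickle.dumps(clips)     # __getstate__ flattens video_pts and tags the dict with _version = 2
restored = pickle.loads(state)  # __setstate__ detects the tag and rebuilds the per-video pts
assert restored.num_clips() == clips.num_clips()
```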