Unverified Commit 29e0f66a authored by F-G Fernandez, committed by GitHub
Browse files

Added typing annotations to io/_video_opts (#4173)



* style: Added typing annotations

* style: Fixed lint

* style: Fixed typing

* chore: Updated mypy.ini

* style: Fixed typing

* chore: Updated mypy.ini

* style: Fixed typing compatibility with jit

* style: Fixed typing

* style: Fixed typing

* style: Fixed missing import

* style: Fixed typing of __iter__

* style: Fixed typing

* style: Fixed lint

* style: Finished typing

* style: ufmt the file

* style: Removed unnecessary typing

* style: Fixed typing of iterator
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Prabhat Roy <prabhatroy@fb.com>
parent 34dc0701
......@@ -22,11 +22,11 @@ warn_unreachable = True
; miscellaneous strictness flags
allow_redefinition = True
[mypy-torchvision.io._video_opt.*]
[mypy-torchvision.io.image.*]
ignore_errors = True
[mypy-torchvision.io.*]
[mypy-torchvision.io.video.*]
ignore_errors = True
......
......@@ -132,7 +132,7 @@ class VideoReader:
raise StopIteration
return {"data": frame, "pts": pts}
def __iter__(self) -> Iterator["VideoReader"]:
def __iter__(self) -> Iterator[Dict[str, Any]]:
    # Return the reader itself as the iterator; each ``__next__`` call
    # yields a ``{"data": frame, "pts": pts}`` dict (see ``__next__`` above).
    """Return an iterator over the decoded frames of this reader."""
    return self
def seek(self, time_s: float, keyframes_only: bool = False) -> "VideoReader":
......
import math
import warnings
from fractions import Fraction
from typing import List, Tuple
from typing import List, Tuple, Dict, Optional, Union
import torch
......@@ -26,10 +26,9 @@ class Timebase:
def __init__(
self,
numerator, # type: int
denominator, # type: int
):
# type: (...) -> None
numerator: int,
denominator: int,
) -> None:
self.numerator = numerator
self.denominator = denominator
......@@ -56,7 +55,7 @@ class VideoMetaData:
"audio_sample_rate",
]
def __init__(self):
def __init__(self) -> None:
self.has_video = False
self.video_timebase = Timebase(0, 1)
self.video_duration = 0.0
......@@ -67,8 +66,7 @@ class VideoMetaData:
self.audio_sample_rate = 0.0
def _validate_pts(pts_range):
# type: (List[int]) -> None
def _validate_pts(pts_range: Tuple[int, int]) -> None:
if pts_range[1] > 0:
assert (
......@@ -80,8 +78,14 @@ def _validate_pts(pts_range):
)
def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration):
# type: (torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor) -> VideoMetaData
def _fill_info(
vtimebase: torch.Tensor,
vfps: torch.Tensor,
vduration: torch.Tensor,
atimebase: torch.Tensor,
asample_rate: torch.Tensor,
aduration: torch.Tensor,
) -> VideoMetaData:
"""
Build update VideoMetaData struct with info about the video
"""
......@@ -106,8 +110,9 @@ def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration):
return meta
def _align_audio_frames(aframes, aframe_pts, audio_pts_range):
# type: (torch.Tensor, torch.Tensor, List[int]) -> torch.Tensor
def _align_audio_frames(
aframes: torch.Tensor, aframe_pts: torch.Tensor, audio_pts_range: Tuple[int, int]
) -> torch.Tensor:
start, end = aframe_pts[0], aframe_pts[-1]
num_samples = aframes.size(0)
step_per_aframe = float(end - start + 1) / float(num_samples)
......@@ -121,21 +126,21 @@ def _align_audio_frames(aframes, aframe_pts, audio_pts_range):
def _read_video_from_file(
filename,
seek_frame_margin=0.25,
read_video_stream=True,
video_width=0,
video_height=0,
video_min_dimension=0,
video_max_dimension=0,
video_pts_range=(0, -1),
video_timebase=default_timebase,
read_audio_stream=True,
audio_samples=0,
audio_channels=0,
audio_pts_range=(0, -1),
audio_timebase=default_timebase,
):
filename: str,
seek_frame_margin: float = 0.25,
read_video_stream: bool = True,
video_width: int = 0,
video_height: int = 0,
video_min_dimension: int = 0,
video_max_dimension: int = 0,
video_pts_range: Tuple[int, int] = (0, -1),
video_timebase: Fraction = default_timebase,
read_audio_stream: bool = True,
audio_samples: int = 0,
audio_channels: int = 0,
audio_pts_range: Tuple[int, int] = (0, -1),
audio_timebase: Fraction = default_timebase,
) -> Tuple[torch.Tensor, torch.Tensor, VideoMetaData]:
"""
Reads a video from a file, returning both the video frames as well as
the audio frames
......@@ -217,7 +222,7 @@ def _read_video_from_file(
return vframes, aframes, info
def _read_video_timestamps_from_file(filename):
def _read_video_timestamps_from_file(filename: str) -> Tuple[List[int], List[int], VideoMetaData]:
"""
Decode all video- and audio frames in the video. Only pts
(presentation timestamp) is returned. The actual frame pixel data is not
......@@ -252,7 +257,7 @@ def _read_video_timestamps_from_file(filename):
return vframe_pts, aframe_pts, info
def _probe_video_from_file(filename):
def _probe_video_from_file(filename: str) -> VideoMetaData:
"""
Probe a video file and return VideoMetaData with info about the video
"""
......@@ -263,24 +268,23 @@ def _probe_video_from_file(filename):
def _read_video_from_memory(
video_data, # type: torch.Tensor
seek_frame_margin=0.25, # type: float
read_video_stream=1, # type: int
video_width=0, # type: int
video_height=0, # type: int
video_min_dimension=0, # type: int
video_max_dimension=0, # type: int
video_pts_range=(0, -1), # type: List[int]
video_timebase_numerator=0, # type: int
video_timebase_denominator=1, # type: int
read_audio_stream=1, # type: int
audio_samples=0, # type: int
audio_channels=0, # type: int
audio_pts_range=(0, -1), # type: List[int]
audio_timebase_numerator=0, # type: int
audio_timebase_denominator=1, # type: int
):
# type: (...) -> Tuple[torch.Tensor, torch.Tensor]
video_data: torch.Tensor,
seek_frame_margin: float = 0.25,
read_video_stream: int = 1,
video_width: int = 0,
video_height: int = 0,
video_min_dimension: int = 0,
video_max_dimension: int = 0,
video_pts_range: Tuple[int, int] = (0, -1),
video_timebase_numerator: int = 0,
video_timebase_denominator: int = 1,
read_audio_stream: int = 1,
audio_samples: int = 0,
audio_channels: int = 0,
audio_pts_range: Tuple[int, int] = (0, -1),
audio_timebase_numerator: int = 0,
audio_timebase_denominator: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Reads a video from memory, returning both the video frames as well as
the audio frames
......@@ -370,7 +374,9 @@ def _read_video_from_memory(
return vframes, aframes
def _read_video_timestamps_from_memory(video_data):
def _read_video_timestamps_from_memory(
video_data: torch.Tensor,
) -> Tuple[List[int], List[int], VideoMetaData]:
"""
Decode all frames in the video. Only pts (presentation timestamp) is returned.
The actual frame pixel data is not copied. Thus, read_video_timestamps(...)
......@@ -407,8 +413,9 @@ def _read_video_timestamps_from_memory(video_data):
return vframe_pts, aframe_pts, info
def _probe_video_from_memory(video_data):
# type: (torch.Tensor) -> VideoMetaData
def _probe_video_from_memory(
video_data: torch.Tensor,
) -> VideoMetaData:
"""
Probe a video in memory and return VideoMetaData with info about the video
This function is torchscriptable
......@@ -421,7 +428,9 @@ def _probe_video_from_memory(video_data):
return info
def _convert_to_sec(start_pts, end_pts, pts_unit, time_base):
def _convert_to_sec(
start_pts: Union[float, Fraction], end_pts: Union[float, Fraction], pts_unit: str, time_base: Fraction
) -> Tuple[Union[float, Fraction], Union[float, Fraction], str]:
if pts_unit == "pts":
start_pts = float(start_pts * time_base)
end_pts = float(end_pts * time_base)
......@@ -429,7 +438,12 @@ def _convert_to_sec(start_pts, end_pts, pts_unit, time_base):
return start_pts, end_pts, pts_unit
def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
def _read_video(
filename: str,
start_pts: Union[float, Fraction] = 0,
end_pts: Optional[Union[float, Fraction]] = None,
pts_unit: str = "pts",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]:
if end_pts is None:
end_pts = float("inf")
......@@ -495,13 +509,16 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
return vframes, aframes, _info
def _read_video_timestamps(filename, pts_unit="pts"):
def _read_video_timestamps(
filename: str, pts_unit: str = "pts"
) -> Tuple[Union[List[int], List[Fraction]], Optional[float]]:
if pts_unit == "pts":
warnings.warn(
"The pts_unit 'pts' gives wrong results and will be removed in a "
+ "follow-up version. Please use pts_unit 'sec'."
)
pts: Union[List[int], List[Fraction]]
pts, _, info = _read_video_timestamps_from_file(filename)
if pts_unit == "sec":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment