"src/vscode:/vscode.git/clone" did not exist on "2715079344b725bdb045f601551dae02509e393e"
Unverified commit 29e0f66a, authored by F-G Fernandez, committed by GitHub
Browse files

Added typing annotations to io/_video_opts (#4173)



* style: Added typing annotations

* style: Fixed lint

* style: Fixed typing

* chore: Updated mypy.ini

* style: Fixed typing

* chore: Updated mypy.ini

* style: Fixed typing compatibility with jit

* style: Fixed typing

* style: Fixed typing

* style: Fixed missing import

* style: Fixed typing of __iter__

* style: Fixed typing

* style: Fixed lint

* style: Finished typing

* style: ufmt the file

* style: Removed unnecessary typing

* style: Fixed typing of iterator
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Prabhat Roy <prabhatroy@fb.com>
parent 34dc0701
...@@ -22,11 +22,11 @@ warn_unreachable = True ...@@ -22,11 +22,11 @@ warn_unreachable = True
; miscellaneous strictness flags ; miscellaneous strictness flags
allow_redefinition = True allow_redefinition = True
[mypy-torchvision.io._video_opt.*] [mypy-torchvision.io.image.*]
ignore_errors = True ignore_errors = True
[mypy-torchvision.io.*] [mypy-torchvision.io.video.*]
ignore_errors = True ignore_errors = True
......
...@@ -132,7 +132,7 @@ class VideoReader: ...@@ -132,7 +132,7 @@ class VideoReader:
raise StopIteration raise StopIteration
return {"data": frame, "pts": pts} return {"data": frame, "pts": pts}
def __iter__(self) -> Iterator["VideoReader"]: def __iter__(self) -> Iterator[Dict[str, Any]]:
return self return self
def seek(self, time_s: float, keyframes_only: bool = False) -> "VideoReader": def seek(self, time_s: float, keyframes_only: bool = False) -> "VideoReader":
......
import math import math
import warnings import warnings
from fractions import Fraction from fractions import Fraction
from typing import List, Tuple from typing import List, Tuple, Dict, Optional, Union
import torch import torch
...@@ -26,10 +26,9 @@ class Timebase: ...@@ -26,10 +26,9 @@ class Timebase:
def __init__( def __init__(
self, self,
numerator, # type: int numerator: int,
denominator, # type: int denominator: int,
): ) -> None:
# type: (...) -> None
self.numerator = numerator self.numerator = numerator
self.denominator = denominator self.denominator = denominator
...@@ -56,7 +55,7 @@ class VideoMetaData: ...@@ -56,7 +55,7 @@ class VideoMetaData:
"audio_sample_rate", "audio_sample_rate",
] ]
def __init__(self): def __init__(self) -> None:
self.has_video = False self.has_video = False
self.video_timebase = Timebase(0, 1) self.video_timebase = Timebase(0, 1)
self.video_duration = 0.0 self.video_duration = 0.0
...@@ -67,8 +66,7 @@ class VideoMetaData: ...@@ -67,8 +66,7 @@ class VideoMetaData:
self.audio_sample_rate = 0.0 self.audio_sample_rate = 0.0
def _validate_pts(pts_range): def _validate_pts(pts_range: Tuple[int, int]) -> None:
# type: (List[int]) -> None
if pts_range[1] > 0: if pts_range[1] > 0:
assert ( assert (
...@@ -80,8 +78,14 @@ def _validate_pts(pts_range): ...@@ -80,8 +78,14 @@ def _validate_pts(pts_range):
) )
def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration): def _fill_info(
# type: (torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor) -> VideoMetaData vtimebase: torch.Tensor,
vfps: torch.Tensor,
vduration: torch.Tensor,
atimebase: torch.Tensor,
asample_rate: torch.Tensor,
aduration: torch.Tensor,
) -> VideoMetaData:
""" """
Build update VideoMetaData struct with info about the video Build update VideoMetaData struct with info about the video
""" """
...@@ -106,8 +110,9 @@ def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration): ...@@ -106,8 +110,9 @@ def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration):
return meta return meta
def _align_audio_frames(aframes, aframe_pts, audio_pts_range): def _align_audio_frames(
# type: (torch.Tensor, torch.Tensor, List[int]) -> torch.Tensor aframes: torch.Tensor, aframe_pts: torch.Tensor, audio_pts_range: Tuple[int, int]
) -> torch.Tensor:
start, end = aframe_pts[0], aframe_pts[-1] start, end = aframe_pts[0], aframe_pts[-1]
num_samples = aframes.size(0) num_samples = aframes.size(0)
step_per_aframe = float(end - start + 1) / float(num_samples) step_per_aframe = float(end - start + 1) / float(num_samples)
...@@ -121,21 +126,21 @@ def _align_audio_frames(aframes, aframe_pts, audio_pts_range): ...@@ -121,21 +126,21 @@ def _align_audio_frames(aframes, aframe_pts, audio_pts_range):
def _read_video_from_file( def _read_video_from_file(
filename, filename: str,
seek_frame_margin=0.25, seek_frame_margin: float = 0.25,
read_video_stream=True, read_video_stream: bool = True,
video_width=0, video_width: int = 0,
video_height=0, video_height: int = 0,
video_min_dimension=0, video_min_dimension: int = 0,
video_max_dimension=0, video_max_dimension: int = 0,
video_pts_range=(0, -1), video_pts_range: Tuple[int, int] = (0, -1),
video_timebase=default_timebase, video_timebase: Fraction = default_timebase,
read_audio_stream=True, read_audio_stream: bool = True,
audio_samples=0, audio_samples: int = 0,
audio_channels=0, audio_channels: int = 0,
audio_pts_range=(0, -1), audio_pts_range: Tuple[int, int] = (0, -1),
audio_timebase=default_timebase, audio_timebase: Fraction = default_timebase,
): ) -> Tuple[torch.Tensor, torch.Tensor, VideoMetaData]:
""" """
Reads a video from a file, returning both the video frames as well as Reads a video from a file, returning both the video frames as well as
the audio frames the audio frames
...@@ -217,7 +222,7 @@ def _read_video_from_file( ...@@ -217,7 +222,7 @@ def _read_video_from_file(
return vframes, aframes, info return vframes, aframes, info
def _read_video_timestamps_from_file(filename): def _read_video_timestamps_from_file(filename: str) -> Tuple[List[int], List[int], VideoMetaData]:
""" """
Decode all video- and audio frames in the video. Only pts Decode all video- and audio frames in the video. Only pts
(presentation timestamp) is returned. The actual frame pixel data is not (presentation timestamp) is returned. The actual frame pixel data is not
...@@ -252,7 +257,7 @@ def _read_video_timestamps_from_file(filename): ...@@ -252,7 +257,7 @@ def _read_video_timestamps_from_file(filename):
return vframe_pts, aframe_pts, info return vframe_pts, aframe_pts, info
def _probe_video_from_file(filename): def _probe_video_from_file(filename: str) -> VideoMetaData:
""" """
Probe a video file and return VideoMetaData with info about the video Probe a video file and return VideoMetaData with info about the video
""" """
...@@ -263,24 +268,23 @@ def _probe_video_from_file(filename): ...@@ -263,24 +268,23 @@ def _probe_video_from_file(filename):
def _read_video_from_memory( def _read_video_from_memory(
video_data, # type: torch.Tensor video_data: torch.Tensor,
seek_frame_margin=0.25, # type: float seek_frame_margin: float = 0.25,
read_video_stream=1, # type: int read_video_stream: int = 1,
video_width=0, # type: int video_width: int = 0,
video_height=0, # type: int video_height: int = 0,
video_min_dimension=0, # type: int video_min_dimension: int = 0,
video_max_dimension=0, # type: int video_max_dimension: int = 0,
video_pts_range=(0, -1), # type: List[int] video_pts_range: Tuple[int, int] = (0, -1),
video_timebase_numerator=0, # type: int video_timebase_numerator: int = 0,
video_timebase_denominator=1, # type: int video_timebase_denominator: int = 1,
read_audio_stream=1, # type: int read_audio_stream: int = 1,
audio_samples=0, # type: int audio_samples: int = 0,
audio_channels=0, # type: int audio_channels: int = 0,
audio_pts_range=(0, -1), # type: List[int] audio_pts_range: Tuple[int, int] = (0, -1),
audio_timebase_numerator=0, # type: int audio_timebase_numerator: int = 0,
audio_timebase_denominator=1, # type: int audio_timebase_denominator: int = 1,
): ) -> Tuple[torch.Tensor, torch.Tensor]:
# type: (...) -> Tuple[torch.Tensor, torch.Tensor]
""" """
Reads a video from memory, returning both the video frames as well as Reads a video from memory, returning both the video frames as well as
the audio frames the audio frames
...@@ -370,7 +374,9 @@ def _read_video_from_memory( ...@@ -370,7 +374,9 @@ def _read_video_from_memory(
return vframes, aframes return vframes, aframes
def _read_video_timestamps_from_memory(video_data): def _read_video_timestamps_from_memory(
video_data: torch.Tensor,
) -> Tuple[List[int], List[int], VideoMetaData]:
""" """
Decode all frames in the video. Only pts (presentation timestamp) is returned. Decode all frames in the video. Only pts (presentation timestamp) is returned.
The actual frame pixel data is not copied. Thus, read_video_timestamps(...) The actual frame pixel data is not copied. Thus, read_video_timestamps(...)
...@@ -407,8 +413,9 @@ def _read_video_timestamps_from_memory(video_data): ...@@ -407,8 +413,9 @@ def _read_video_timestamps_from_memory(video_data):
return vframe_pts, aframe_pts, info return vframe_pts, aframe_pts, info
def _probe_video_from_memory(video_data): def _probe_video_from_memory(
# type: (torch.Tensor) -> VideoMetaData video_data: torch.Tensor,
) -> VideoMetaData:
""" """
Probe a video in memory and return VideoMetaData with info about the video Probe a video in memory and return VideoMetaData with info about the video
This function is torchscriptable This function is torchscriptable
...@@ -421,7 +428,9 @@ def _probe_video_from_memory(video_data): ...@@ -421,7 +428,9 @@ def _probe_video_from_memory(video_data):
return info return info
def _convert_to_sec(start_pts, end_pts, pts_unit, time_base): def _convert_to_sec(
start_pts: Union[float, Fraction], end_pts: Union[float, Fraction], pts_unit: str, time_base: Fraction
) -> Tuple[Union[float, Fraction], Union[float, Fraction], str]:
if pts_unit == "pts": if pts_unit == "pts":
start_pts = float(start_pts * time_base) start_pts = float(start_pts * time_base)
end_pts = float(end_pts * time_base) end_pts = float(end_pts * time_base)
...@@ -429,7 +438,12 @@ def _convert_to_sec(start_pts, end_pts, pts_unit, time_base): ...@@ -429,7 +438,12 @@ def _convert_to_sec(start_pts, end_pts, pts_unit, time_base):
return start_pts, end_pts, pts_unit return start_pts, end_pts, pts_unit
def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"): def _read_video(
filename: str,
start_pts: Union[float, Fraction] = 0,
end_pts: Optional[Union[float, Fraction]] = None,
pts_unit: str = "pts",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]:
if end_pts is None: if end_pts is None:
end_pts = float("inf") end_pts = float("inf")
...@@ -495,13 +509,16 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"): ...@@ -495,13 +509,16 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
return vframes, aframes, _info return vframes, aframes, _info
def _read_video_timestamps(filename, pts_unit="pts"): def _read_video_timestamps(
filename: str, pts_unit: str = "pts"
) -> Tuple[Union[List[int], List[Fraction]], Optional[float]]:
if pts_unit == "pts": if pts_unit == "pts":
warnings.warn( warnings.warn(
"The pts_unit 'pts' gives wrong results and will be removed in a " "The pts_unit 'pts' gives wrong results and will be removed in a "
+ "follow-up version. Please use pts_unit 'sec'." + "follow-up version. Please use pts_unit 'sec'."
) )
pts: Union[List[int], List[Fraction]]
pts, _, info = _read_video_timestamps_from_file(filename) pts, _, info = _read_video_timestamps_from_file(filename)
if pts_unit == "sec": if pts_unit == "sec":
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment