Unverified commit f8bf06d5 authored by Philip Meier, committed by GitHub

Add typehints for torchvision.io (#2543)



* enable typing check for torchvision.io

* fix existing errors

* Update torchvision/io/_video_opt.py
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>

* add ignores for FileFinder

* use python 3 type hints

* lint

* video_opt

* video

* try quote av type hints

* revert from .dim() to .ndim

* revert changes to _video_opt.py and ignore errors

* fix type hints

* fix type hints for read_video_timestamps

* change offset int to float

* remove unused import
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent 854ead08
mypy.ini
@@ -8,7 +8,7 @@ pretty = True
 ;ignore_errors = True
-[mypy-torchvision.io.*]
+[mypy-torchvision.io._video_opt.*]
 ignore_errors = True

@@ -51,3 +51,7 @@ ignore_missing_imports = True
 [mypy-accimage.*]
 ignore_missing_imports = True
+
+[mypy-av.*]
+ignore_missing_imports = True
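The new [mypy-av.*] section is needed because PyAV ships without type stubs, so any `import av` inside torchvision.io would otherwise fail the check this PR enables. A minimal sketch of the effect, assuming PyAV is installed but unstubbed (check.py is a made-up file name):

# check.py
# Without the [mypy-av.*] override, `mypy check.py` reports something like:
#     error: Cannot find implementation or library stub for module named 'av'
# With ignore_missing_imports = True, mypy types the module as Any instead.
import av

print(av.__version__)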
torchvision/io/image.py
 import torch
-from torch import nn, Tensor
 import os
 import os.path as osp
-import importlib
+import importlib.machinery

 _HAS_IMAGE_OPT = False
@@ -15,7 +14,7 @@ try:
         importlib.machinery.EXTENSION_SUFFIXES
     )
-    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
+    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)  # type: ignore[arg-type]
     ext_specs = extfinder.find_spec("image")
     if ext_specs is not None:
         torch.ops.load_library(ext_specs.origin)
@@ -24,8 +23,7 @@ except (ImportError, OSError):
     pass
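The `# type: ignore[arg-type]` on the FileFinder call (the "add ignores for FileFinder" item in the commit message) works around typeshed annotating the loader_details parameter more strictly than this call site satisfies. A standalone sketch of the loader pattern, with lib_dir and the module name "image" taken from the diff above:

import importlib.machinery
import os.path as osp

lib_dir = osp.dirname(__file__)  # assumption: the compiled extension sits next to this file
loader_details = (
    importlib.machinery.ExtensionFileLoader,
    importlib.machinery.EXTENSION_SUFFIXES,
)
# FileFinder scans lib_dir for files matching the given loader/suffix pairs.
extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)  # type: ignore[arg-type]
spec = extfinder.find_spec("image")
if spec is not None:
    print("compiled extension found at", spec.origin)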
-def decode_png(input):
-    # type: (Tensor) -> Tensor
+def decode_png(input: torch.Tensor) -> torch.Tensor:
     """
     Decodes a PNG image into a 3 dimensional RGB Tensor.
     The values of the output tensor are uint8 between 0 and 255.
@@ -37,7 +35,7 @@ def decode_png(input):
     Returns:
         output (Tensor[image_width, image_height, 3])
     """
-    if not isinstance(input, torch.Tensor) or input.numel() == 0 or input.ndim != 1:
+    if not isinstance(input, torch.Tensor) or input.numel() == 0 or input.ndim != 1:  # type: ignore[attr-defined]
         raise ValueError("Expected a non empty 1-dimensional tensor.")
     if not input.dtype == torch.uint8:
@@ -46,8 +44,7 @@ def decode_png(input):
     return output

-def read_png(path):
-    # type: (str) -> Tensor
+def read_png(path: str) -> torch.Tensor:
     """
     Reads a PNG image into a 3 dimensional RGB Tensor.
     The values of the output tensor are uint8 between 0 and 255.
@@ -68,8 +65,7 @@ def read_png(path):
     return decode_png(data)
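With the annotations in place, callers are checked against the non-empty 1-D uint8 contract that the guard above enforces. A hedged usage sketch, importing from the module this diff touches ("example.png" is a placeholder path):

import numpy as np
import torch
from torchvision.io.image import decode_png

raw = torch.from_numpy(np.fromfile("example.png", dtype=np.uint8))  # encoded bytes as a 1-D uint8 tensor
img = decode_png(raw)  # anything else raises ValueError per the check above
print(img.dtype, img.shape)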
-def decode_jpeg(input):
-    # type: (Tensor) -> Tensor
+def decode_jpeg(input: torch.Tensor) -> torch.Tensor:
     """
     Decodes a JPEG image into a 3 dimensional RGB Tensor.
     The values of the output tensor are uint8 between 0 and 255.
@@ -79,7 +75,7 @@ def decode_jpeg(input):
     Returns:
         output (Tensor[image_width, image_height, 3])
     """
-    if not isinstance(input, torch.Tensor) or len(input) == 0 or input.ndim != 1:
+    if not isinstance(input, torch.Tensor) or len(input) == 0 or input.ndim != 1:  # type: ignore[attr-defined]
         raise ValueError("Expected a non empty 1-dimensional tensor.")
     if not input.dtype == torch.uint8:
@@ -89,8 +85,7 @@ def decode_jpeg(input):
     return output

-def read_jpeg(path):
-    # type: (str) -> Tensor
+def read_jpeg(path: str) -> torch.Tensor:
     """
     Reads a JPEG image into a 3 dimensional RGB Tensor.
     The values of the output tensor are uint8 between 0 and 255.
...
torchvision/io/video.py
@@ -2,7 +2,7 @@ import gc
 import math
 import re
 import warnings
-from typing import List, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -35,12 +35,12 @@ install PyAV on your system.
     )

-def _check_av_available():
+def _check_av_available() -> None:
     if isinstance(av, Exception):
         raise av

-def _av_available():
+def _av_available() -> bool:
     return not isinstance(av, Exception)
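Both helpers lean on a module-level sentinel: when PyAV fails to import, the name `av` is bound to the exception instead of the module. A self-contained sketch of that pattern (not torchvision's exact import block):

try:
    import av
except ImportError as e:
    av = e  # type: ignore[assignment]  # sentinel: the name now holds the error


def _av_available() -> bool:
    # Availability reduces to "is the name still a module, not an exception?"
    return not isinstance(av, Exception)


def _check_av_available() -> None:
    if isinstance(av, Exception):
        raise av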
@@ -49,7 +49,13 @@ _CALLED_TIMES = 0
 _GC_COLLECTION_INTERVAL = 10

-def write_video(filename, video_array, fps: Union[int, float], video_codec="libx264", options=None):
+def write_video(
+    filename: str,
+    video_array: torch.Tensor,
+    fps: float,
+    video_codec: str = "libx264",
+    options: Optional[Dict[str, Any]] = None,
+) -> None:
     """
     Writes a 4d tensor in [T, H, W, C] format in a video file
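Note that fps narrows from Union[int, float] to plain float: under PEP 484's numeric tower an int argument still type-checks where a float is expected, so no caller breaks. A hedged usage sketch ("out.mp4" and the crf value are placeholders):

import torch
from torchvision.io import write_video

frames = torch.zeros((16, 64, 64, 3), dtype=torch.uint8)  # [T, H, W, C]
write_video("out.mp4", frames, fps=30, options={"crf": "23"})  # int fps is fine where float is expected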
@@ -89,8 +95,13 @@ def write_video(filename, video_array, fps: Union[int, float], video_codec="libx
 def _read_from_stream(
-    container, start_offset, end_offset, pts_unit, stream, stream_name
-):
+    container: "av.container.Container",
+    start_offset: float,
+    end_offset: float,
+    pts_unit: str,
+    stream: "av.stream.Stream",
+    stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
+) -> List["av.frame.Frame"]:
     global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
     _CALLED_TIMES += 1
     if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
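The av types are quoted because they must never be evaluated at runtime: when PyAV is missing, `av` is an exception object with no `.container` attribute (this is what the "try quote av type hints" item in the commit message refers to). A sketch of the standard pattern using typing.TYPE_CHECKING (_first_stream is a hypothetical helper for illustration):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import av  # resolved only by the type checker, never at runtime


def _first_stream(container: "av.container.Container") -> "av.stream.Stream":
    # The quoted annotations are lazy forward references, so they stay valid
    # even when PyAV is absent and the module-level `av` holds an exception.
    return container.streams[0]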
@@ -166,7 +177,9 @@ def _read_from_stream(
     return result

-def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
+def _align_audio_frames(
+    aframes: torch.Tensor, audio_frames: List["av.frame.Frame"], ref_start: int, ref_end: float
+) -> torch.Tensor:
     start, end = audio_frames[0].pts, audio_frames[-1].pts
     total_aframes = aframes.shape[1]
     step_per_aframe = (end - start + 1) / total_aframes
@@ -179,7 +192,9 @@ def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
     return aframes[:, s_idx:e_idx]
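ref_end is annotated float while ref_start stays int, presumably because the end offset can arrive as math.inf, which is the "change offset int to float" item in the commit message. A worked sketch of the visible arithmetic with made-up numbers:

# Suppose the decoded audio frames span pts 0..999 and unpack to 2000 samples.
start, end = 0, 999
total_aframes = 2000
step_per_aframe = (end - start + 1) / total_aframes  # 0.5 pts per sample
# Mapping a reference pts back to a sample index is then a division:
ref_start = 100
s_idx = int((ref_start - start) / step_per_aframe)  # sample 200
print(step_per_aframe, s_idx)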
-def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
+def read_video(
+    filename: str, start_pts: int = 0, end_pts: Optional[float] = None, pts_unit: str = "pts"
+) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
     """
     Reads a video from a file, returning both the video frames as well as
     the audio frames
@@ -260,16 +275,16 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
             # TODO raise a warning?
             pass

-    vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
-    aframes = [frame.to_ndarray() for frame in audio_frames]
+    vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
+    aframes_list = [frame.to_ndarray() for frame in audio_frames]

-    if vframes:
-        vframes = torch.as_tensor(np.stack(vframes))
+    if vframes_list:
+        vframes = torch.as_tensor(np.stack(vframes_list))
     else:
         vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

-    if aframes:
-        aframes = np.concatenate(aframes, 1)
+    if aframes_list:
+        aframes = np.concatenate(aframes_list, 1)
         aframes = torch.as_tensor(aframes)
         aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
     else:
@@ -278,7 +293,7 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
     return vframes, aframes, info
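The vframes_list/aframes_list renames exist for mypy's benefit: a variable inferred as a list of arrays cannot later be rebound to a Tensor, so the intermediate lists get their own names while vframes/aframes stay Tensors throughout. From the caller's side the annotated signature reads as in this hedged sketch ("clip.mp4" is a placeholder):

from torchvision.io import read_video

vframes, aframes, info = read_video("clip.mp4", pts_unit="sec")
# vframes: uint8 video tensor, aframes: audio tensor,
# info: Dict[str, Any] with metadata such as {"video_fps": ..., "audio_fps": ...}
print(vframes.shape, aframes.shape, info)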
-def _can_read_timestamps_from_packets(container):
+def _can_read_timestamps_from_packets(container: "av.container.Container") -> bool:
     extradata = container.streams[0].codec_context.extradata
     if extradata is None:
         return False
@@ -287,7 +302,7 @@ def _can_read_timestamps_from_packets(container):
     return False

-def _decode_video_timestamps(container):
+def _decode_video_timestamps(container: "av.container.Container") -> List[int]:
     if _can_read_timestamps_from_packets(container):
         # fast path
         return [x.pts for x in container.demux(video=0) if x.pts is not None]
@@ -295,7 +310,7 @@ def _decode_video_timestamps(container):
     return [x.pts for x in container.decode(video=0) if x.pts is not None]
-def read_video_timestamps(filename, pts_unit="pts"):
+def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[int], Optional[float]]:
     """
     List the video frames timestamps.
@@ -313,7 +328,7 @@ def read_video_timestamps(filename, pts_unit="pts"):
         pts : List[int] if pts_unit = 'pts'
             List[Fraction] if pts_unit = 'sec'
             presentation timestamps for each one of the frames in the video.
-        video_fps : int
+        video_fps : float, optional
             the frame rate for the video
     """
...