Unverified commit f8bf06d5, authored by Philip Meier and committed by GitHub
Browse files

Add typehints for torchvision.io (#2543)



* enable typing check for torchvision.io

* fix existing errors

* Update torchvision/io/_video_opt.py
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>

* add ignores for FileFinder

* use python 3 type hints

* lint

* video_opt

* video

* try quote av type hints

* revert from .dim() to .ndim

* revert changes to _video_opt.py and ignore errors

* fix type hints

* fix type hints for read_video_timestamps

* change offset int to float

* remove unused import
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent 854ead08
......@@ -8,7 +8,7 @@ pretty = True
;ignore_errors = True
[mypy-torchvision.io.*]
[mypy-torchvision.io._video_opt.*]
ignore_errors = True
......@@ -51,3 +51,7 @@ ignore_missing_imports = True
[mypy-accimage.*]
ignore_missing_imports = True
[mypy-av.*]
ignore_missing_imports = True
import torch
from torch import nn, Tensor
import os
import os.path as osp
import importlib
import importlib.machinery
_HAS_IMAGE_OPT = False
......@@ -15,7 +14,7 @@ try:
importlib.machinery.EXTENSION_SUFFIXES
)
extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) # type: ignore[arg-type]
ext_specs = extfinder.find_spec("image")
if ext_specs is not None:
torch.ops.load_library(ext_specs.origin)
......@@ -24,8 +23,7 @@ except (ImportError, OSError):
pass
def decode_png(input):
# type: (Tensor) -> Tensor
def decode_png(input: torch.Tensor) -> torch.Tensor:
"""
Decodes a PNG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
......@@ -37,7 +35,7 @@ def decode_png(input):
Returns:
output (Tensor[image_width, image_height, 3])
"""
if not isinstance(input, torch.Tensor) or input.numel() == 0 or input.ndim != 1:
if not isinstance(input, torch.Tensor) or input.numel() == 0 or input.ndim != 1: # type: ignore[attr-defined]
raise ValueError("Expected a non empty 1-dimensional tensor.")
if not input.dtype == torch.uint8:
......@@ -46,8 +44,7 @@ def decode_png(input):
return output
def read_png(path):
# type: (str) -> Tensor
def read_png(path: str) -> torch.Tensor:
"""
Reads a PNG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
......@@ -68,8 +65,7 @@ def read_png(path):
return decode_png(data)
def decode_jpeg(input):
# type: (Tensor) -> Tensor
def decode_jpeg(input: torch.Tensor) -> torch.Tensor:
"""
Decodes a JPEG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
......@@ -79,7 +75,7 @@ def decode_jpeg(input):
Returns:
output (Tensor[image_width, image_height, 3])
"""
if not isinstance(input, torch.Tensor) or len(input) == 0 or input.ndim != 1:
if not isinstance(input, torch.Tensor) or len(input) == 0 or input.ndim != 1: # type: ignore[attr-defined]
raise ValueError("Expected a non empty 1-dimensional tensor.")
if not input.dtype == torch.uint8:
......@@ -89,8 +85,7 @@ def decode_jpeg(input):
return output
def read_jpeg(path):
# type: (str) -> Tensor
def read_jpeg(path: str) -> torch.Tensor:
"""
Reads a JPEG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
......
......@@ -2,7 +2,7 @@ import gc
import math
import re
import warnings
from typing import List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
......@@ -35,12 +35,12 @@ install PyAV on your system.
)
def _check_av_available():
def _check_av_available() -> None:
if isinstance(av, Exception):
raise av
def _av_available():
def _av_available() -> bool:
return not isinstance(av, Exception)
......@@ -49,7 +49,13 @@ _CALLED_TIMES = 0
_GC_COLLECTION_INTERVAL = 10
def write_video(filename, video_array, fps: Union[int, float], video_codec="libx264", options=None):
def write_video(
filename: str,
video_array: torch.Tensor,
fps: float,
video_codec: str = "libx264",
options: Optional[Dict[str, Any]] = None,
) -> None:
"""
Writes a 4d tensor in [T, H, W, C] format in a video file
......@@ -89,8 +95,13 @@ def write_video(filename, video_array, fps: Union[int, float], video_codec="libx
def _read_from_stream(
container, start_offset, end_offset, pts_unit, stream, stream_name
):
container: "av.container.Container",
start_offset: float,
end_offset: float,
pts_unit: str,
stream: "av.stream.Stream",
stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
) -> List["av.frame.Frame"]:
global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
_CALLED_TIMES += 1
if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
......@@ -166,7 +177,9 @@ def _read_from_stream(
return result
def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
def _align_audio_frames(
aframes: torch.Tensor, audio_frames: List["av.frame.Frame"], ref_start: int, ref_end: float
) -> torch.Tensor:
start, end = audio_frames[0].pts, audio_frames[-1].pts
total_aframes = aframes.shape[1]
step_per_aframe = (end - start + 1) / total_aframes
......@@ -179,7 +192,9 @@ def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
return aframes[:, s_idx:e_idx]
def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
def read_video(
filename: str, start_pts: int = 0, end_pts: Optional[float] = None, pts_unit: str = "pts"
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
"""
Reads a video from a file, returning both the video frames as well as
the audio frames
......@@ -260,16 +275,16 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
# TODO raise a warning?
pass
vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
aframes = [frame.to_ndarray() for frame in audio_frames]
vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
aframes_list = [frame.to_ndarray() for frame in audio_frames]
if vframes:
vframes = torch.as_tensor(np.stack(vframes))
if vframes_list:
vframes = torch.as_tensor(np.stack(vframes_list))
else:
vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
if aframes:
aframes = np.concatenate(aframes, 1)
if aframes_list:
aframes = np.concatenate(aframes_list, 1)
aframes = torch.as_tensor(aframes)
aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
else:
......@@ -278,7 +293,7 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
return vframes, aframes, info
def _can_read_timestamps_from_packets(container):
def _can_read_timestamps_from_packets(container: "av.container.Container") -> bool:
extradata = container.streams[0].codec_context.extradata
if extradata is None:
return False
......@@ -287,7 +302,7 @@ def _can_read_timestamps_from_packets(container):
return False
def _decode_video_timestamps(container):
def _decode_video_timestamps(container: "av.container.Container") -> List[int]:
if _can_read_timestamps_from_packets(container):
# fast path
return [x.pts for x in container.demux(video=0) if x.pts is not None]
......@@ -295,7 +310,7 @@ def _decode_video_timestamps(container):
return [x.pts for x in container.decode(video=0) if x.pts is not None]
def read_video_timestamps(filename, pts_unit="pts"):
def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[int], Optional[float]]:
"""
List the video frames timestamps.
......@@ -313,7 +328,7 @@ def read_video_timestamps(filename, pts_unit="pts"):
pts : List[int] if pts_unit = 'pts'
List[Fraction] if pts_unit = 'sec'
presentation timestamps for each one of the frames in the video.
video_fps : int
video_fps : float, optional
the frame rate for the video
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment