Unverified Commit 42aa9b26 authored by Quentin Duval, committed by GitHub


Refactoring to use context managers, list comprehensions when more idiomatic, and minor renaming to help reader clarity (#2335)

* Refactoring to use context managers, list comprehensions when more idiomatic, and minor renaming to help reader clarity.

* Fix flake8 warning in video_utils.py
parent 32f21dad
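The diff below applies two standard Python idioms over and over: a filter()/lambda pipeline becomes a list comprehension, and a manually opened-and-closed resource becomes a with block. As a minimal, self-contained sketch of both patterns (the function names and file paths here are made up for illustration, not taken from torchvision):

import os
import shutil


# Before: filter() + lambda, materialized with list()
def find_json_files_old(root):
    return list(filter(lambda p: p.endswith(".json"), os.listdir(root)))


# After: a list comprehension says the same thing more directly
def find_json_files(root):
    return [p for p in os.listdir(root) if p.endswith(".json")]


# Before: open/close by hand; the handles leak if copyfileobj() raises
def copy_file_old(src, dst):
    fin = open(src, "rb")
    fout = open(dst, "wb")
    shutil.copyfileobj(fin, fout)
    fout.close()
    fin.close()


# After: context managers close both files even on error
def copy_file(src, dst):
    with open(src, "rb") as fin, open(dst, "wb") as fout:
        shutil.copyfileobj(fin, fout)

The with form is the more consequential of the two: the old manual close is skipped whenever an exception escapes before close() runs, which is exactly what the write_video and read_video hunks below fix.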
@@ -95,16 +95,9 @@ def list_dir(root, prefix=False):
             only returns the name of the directories found
     """
     root = os.path.expanduser(root)
-    directories = list(
-        filter(
-            lambda p: os.path.isdir(os.path.join(root, p)),
-            os.listdir(root)
-        )
-    )
-
+    directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
     if prefix is True:
         directories = [os.path.join(root, d) for d in directories]
-
     return directories
@@ -119,16 +112,9 @@ def list_files(root, suffix, prefix=False):
             only returns the name of the files found
     """
     root = os.path.expanduser(root)
-    files = list(
-        filter(
-            lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
-            os.listdir(root)
-        )
-    )
-
+    files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
     if prefix is True:
         files = [os.path.join(root, d) for d in files]
-
    return files
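As a sanity check that the comprehension in the two hunks above is behavior-preserving, both variants of list_dir can be run against the same throwaway directory. This is a test sketch, not part of the patch:

import os
import tempfile


def list_dir_old(root, prefix=False):
    # pre-refactor form, kept here only for comparison
    root = os.path.expanduser(root)
    directories = list(
        filter(lambda p: os.path.isdir(os.path.join(root, p)), os.listdir(root))
    )
    if prefix is True:
        directories = [os.path.join(root, d) for d in directories]
    return directories


def list_dir_new(root, prefix=False):
    # post-refactor form from the hunk above
    root = os.path.expanduser(root)
    directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
    if prefix is True:
        directories = [os.path.join(root, d) for d in directories]
    return directories


with tempfile.TemporaryDirectory() as tmp:
    os.mkdir(os.path.join(tmp, "a"))
    os.mkdir(os.path.join(tmp, "b"))
    open(os.path.join(tmp, "file.txt"), "w").close()
    assert sorted(list_dir_old(tmp)) == sorted(list_dir_new(tmp)) == ["a", "b"]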
......
 import bisect
 import math
 from fractions import Fraction
+from typing import List
 import torch
 from torchvision.io import (
@@ -45,20 +46,23 @@ def unfold(tensor, size, step, dilation=1):
     return torch.as_strided(tensor, new_size, new_stride)


-class _DummyDataset(object):
+class _VideoTimestampsDataset(object):
     """
-    Dummy dataset used for DataLoader in VideoClips.
-    Defined at top level so it can be pickled when forking.
+    Dataset used to parallelize the reading of the timestamps
+    of a list of videos, given their paths in the filesystem.
+
+    Used in VideoClips and defined at top level so it can be
+    pickled when forking.
     """
-    def __init__(self, x):
-        self.x = x
+    def __init__(self, video_paths: List[str]):
+        self.video_paths = video_paths

     def __len__(self):
-        return len(self.x)
+        return len(self.video_paths)

     def __getitem__(self, idx):
-        return read_video_timestamps(self.x[idx])
+        return read_video_timestamps(self.video_paths[idx])


 class VideoClips(object):
@@ -132,7 +136,7 @@ class VideoClips(object):
         import torch.utils.data
         dl = torch.utils.data.DataLoader(
-            _DummyDataset(self.video_paths),
+            _VideoTimestampsDataset(self.video_paths),
             batch_size=16,
             num_workers=self.num_workers,
             collate_fn=self._collate_fn,
......
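The renamed class only works with DataLoader worker processes because it lives at module level: workers pickle the dataset by qualified name when forking or spawning. Below is a minimal sketch of the same fan-out pattern, with made-up names (_PathLengthsDataset, _identity_collate) standing in for _VideoTimestampsDataset and for VideoClips._collate_fn, whose body is not shown in this diff:

import torch.utils.data


# Top-level (module-level) class: only such classes can be pickled by name,
# which the DataLoader workers need when the process is forked or spawned.
class _PathLengthsDataset(object):
    """Toy stand-in for _VideoTimestampsDataset: maps each path to some
    per-file work (here, just the length of the path string)."""

    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        return len(self.paths[idx])


def _identity_collate(batch):
    # keep the per-item results as a plain Python list instead of stacking tensors
    return batch


if __name__ == "__main__":
    paths = ["videos/a.mp4", "videos/b.mp4", "videos/c.mp4", "videos/d.mp4"]
    dl = torch.utils.data.DataLoader(
        _PathLengthsDataset(paths),
        batch_size=2,
        num_workers=2,
        collate_fn=_identity_collate,
    )
    results = []
    for batch in dl:
        results.extend(batch)
    print(results)  # [12, 12, 12, 12]

Each worker calls __getitem__ independently, so the expensive per-item call (read_video_timestamps in VideoClips) runs in parallel across num_workers processes.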
@@ -70,27 +70,23 @@ def write_video(filename, video_array, fps: Union[int, float], video_codec="libx
     if isinstance(fps, float):
         fps = np.round(fps)

-    container = av.open(filename, mode="w")
-
-    stream = container.add_stream(video_codec, rate=fps)
-    stream.width = video_array.shape[2]
-    stream.height = video_array.shape[1]
-    stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
-    stream.options = options or {}
-
-    for img in video_array:
-        frame = av.VideoFrame.from_ndarray(img, format="rgb24")
-        frame.pict_type = "NONE"
-        for packet in stream.encode(frame):
-            container.mux(packet)
-
-    # Flush stream
-    for packet in stream.encode():
-        container.mux(packet)
-
-    # Close the file
-    container.close()
+    with av.open(filename, mode="w") as container:
+        stream = container.add_stream(video_codec, rate=fps)
+        stream.width = video_array.shape[2]
+        stream.height = video_array.shape[1]
+        stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
+        stream.options = options or {}
+
+        for img in video_array:
+            frame = av.VideoFrame.from_ndarray(img, format="rgb24")
+            frame.pict_type = "NONE"
+            for packet in stream.encode(frame):
+                container.mux(packet)
+
+        # Flush stream
+        for packet in stream.encode():
+            container.mux(packet)


 def _read_from_stream(
     container, start_offset, end_offset, pts_unit, stream, stream_name
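With the container now opened via with, it is closed even if encoding raises partway through, while the explicit stream flush still happens inside the block. A usage sketch of the public API this hunk touches; the output filename and tensor contents are arbitrary:

import torch
from torchvision.io import write_video

# 30 frames of 64x64 RGB noise; write_video expects a uint8 tensor shaped (T, H, W, C)
video = torch.randint(0, 256, (30, 64, 64, 3), dtype=torch.uint8)
write_video("noise.mp4", video, fps=24)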
@@ -234,37 +230,35 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
     audio_frames = []

     try:
-        container = av.open(filename, metadata_errors="ignore")
+        with av.open(filename, metadata_errors="ignore") as container:
+            if container.streams.video:
+                video_frames = _read_from_stream(
+                    container,
+                    start_pts,
+                    end_pts,
+                    pts_unit,
+                    container.streams.video[0],
+                    {"video": 0},
+                )
+                video_fps = container.streams.video[0].average_rate
+                # guard against potentially corrupted files
+                if video_fps is not None:
+                    info["video_fps"] = float(video_fps)
+
+            if container.streams.audio:
+                audio_frames = _read_from_stream(
+                    container,
+                    start_pts,
+                    end_pts,
+                    pts_unit,
+                    container.streams.audio[0],
+                    {"audio": 0},
+                )
+                info["audio_fps"] = container.streams.audio[0].rate
     except av.AVError:
         # TODO raise a warning?
         pass
-    else:
-        if container.streams.video:
-            video_frames = _read_from_stream(
-                container,
-                start_pts,
-                end_pts,
-                pts_unit,
-                container.streams.video[0],
-                {"video": 0},
-            )
-            video_fps = container.streams.video[0].average_rate
-            # guard against potentially corrupted files
-            if video_fps is not None:
-                info["video_fps"] = float(video_fps)
-
-        if container.streams.audio:
-            audio_frames = _read_from_stream(
-                container,
-                start_pts,
-                end_pts,
-                pts_unit,
-                container.streams.audio[0],
-                {"audio": 0},
-            )
-            info["audio_fps"] = container.streams.audio[0].rate
-
-        container.close()

     vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
     aframes = [frame.to_ndarray() for frame in audio_frames]
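Same change on the read side: the container is released even when _read_from_stream raises, and the except av.AVError branch keeps returning empty frame lists for unreadable files. A quick usage sketch of read_video (the filename is the hypothetical one from the write_video sketch above):

from torchvision.io import read_video

# vframes: uint8 tensor (T, H, W, C); aframes: audio samples; info: dict carrying
# "video_fps" and, when the file has an audio stream, "audio_fps"
vframes, aframes, info = read_video("noise.mp4", pts_unit="sec")
print(vframes.shape, info.get("video_fps"))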
@@ -293,6 +287,14 @@ def _can_read_timestamps_from_packets(container):
     return False


+def _decode_video_timestamps(container):
+    if _can_read_timestamps_from_packets(container):
+        # fast path
+        return [x.pts for x in container.demux(video=0) if x.pts is not None]
+    else:
+        return [x.pts for x in container.decode(video=0) if x.pts is not None]
+
+
 def read_video_timestamps(filename, pts_unit="pts"):
     """
     List the video frames timestamps.
@@ -326,26 +328,18 @@ def read_video_timestamps(filename, pts_unit="pts"):
     pts = []

     try:
-        container = av.open(filename, metadata_errors="ignore")
+        with av.open(filename, metadata_errors="ignore") as container:
+            if container.streams.video:
+                video_stream = container.streams.video[0]
+                video_time_base = video_stream.time_base
+                try:
+                    pts = _decode_video_timestamps(container)
+                except av.AVError:
+                    warnings.warn(f"Failed decoding frames for file {filename}")
+                video_fps = float(video_stream.average_rate)
     except av.AVError:
         # TODO add a warning
         pass
-    else:
-        if container.streams.video:
-            video_stream = container.streams.video[0]
-            video_time_base = video_stream.time_base
-            try:
-                if _can_read_timestamps_from_packets(container):
-                    # fast path
-                    pts = [x.pts for x in container.demux(video=0) if x.pts is not None]
-                else:
-                    pts = [
-                        x.pts for x in container.decode(video=0) if x.pts is not None
-                    ]
-            except av.AVError:
-                warnings.warn(f"Failed decoding frames for file {filename}")
-            video_fps = float(video_stream.average_rate)
-        container.close()

     pts.sort()
......
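The duplicated demux/decode branch now lives in _decode_video_timestamps, so this function shrinks to the with block above. A usage sketch of the public entry point (same hypothetical file as before):

from torchvision.io import read_video_timestamps

# pts: sorted frame timestamps (in seconds here); video_fps: the frame rate, when readable
pts, video_fps = read_video_timestamps("noise.mp4", pts_unit="sec")
print(len(pts), video_fps)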