"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "987d34b0cf8d6cd8725258332fcfc8c54529b1ab"
Unverified Commit 42aa9b26 authored by Quentin Duval's avatar Quentin Duval Committed by GitHub
Browse files

Refactoring to use context managers, list comprehensions when more idiomatic,...

Refactoring to use context managers, list comprehensions when more idiomatic, and minor renaming to improve readability (#2335)

* Refactoring to use context managers, list comprehensions when more idiomatic, and minor renaming to improve readability.

* Fix flake8 warning in video_utils.py
parent 32f21dad
...@@ -95,16 +95,9 @@ def list_dir(root, prefix=False): ...@@ -95,16 +95,9 @@ def list_dir(root, prefix=False):
only returns the name of the directories found only returns the name of the directories found
""" """
root = os.path.expanduser(root) root = os.path.expanduser(root)
directories = list( directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
filter(
lambda p: os.path.isdir(os.path.join(root, p)),
os.listdir(root)
)
)
if prefix is True: if prefix is True:
directories = [os.path.join(root, d) for d in directories] directories = [os.path.join(root, d) for d in directories]
return directories return directories
...@@ -119,16 +112,9 @@ def list_files(root, suffix, prefix=False): ...@@ -119,16 +112,9 @@ def list_files(root, suffix, prefix=False):
only returns the name of the files found only returns the name of the files found
""" """
root = os.path.expanduser(root) root = os.path.expanduser(root)
files = list( files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
filter(
lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
os.listdir(root)
)
)
if prefix is True: if prefix is True:
files = [os.path.join(root, d) for d in files] files = [os.path.join(root, d) for d in files]
return files return files
......
import bisect import bisect
import math import math
from fractions import Fraction from fractions import Fraction
from typing import List
import torch import torch
from torchvision.io import ( from torchvision.io import (
...@@ -45,20 +46,23 @@ def unfold(tensor, size, step, dilation=1): ...@@ -45,20 +46,23 @@ def unfold(tensor, size, step, dilation=1):
return torch.as_strided(tensor, new_size, new_stride) return torch.as_strided(tensor, new_size, new_stride)
class _DummyDataset(object): class _VideoTimestampsDataset(object):
""" """
Dummy dataset used for DataLoader in VideoClips. Dataset used to parallelize the reading of the timestamps
Defined at top level so it can be pickled when forking. of a list of videos, given their paths in the filesystem.
Used in VideoClips and defined at top level so it can be
pickled when forking.
""" """
def __init__(self, x): def __init__(self, video_paths: List[str]):
self.x = x self.video_paths = video_paths
def __len__(self): def __len__(self):
return len(self.x) return len(self.video_paths)
def __getitem__(self, idx): def __getitem__(self, idx):
return read_video_timestamps(self.x[idx]) return read_video_timestamps(self.video_paths[idx])
class VideoClips(object): class VideoClips(object):
...@@ -132,7 +136,7 @@ class VideoClips(object): ...@@ -132,7 +136,7 @@ class VideoClips(object):
import torch.utils.data import torch.utils.data
dl = torch.utils.data.DataLoader( dl = torch.utils.data.DataLoader(
_DummyDataset(self.video_paths), _VideoTimestampsDataset(self.video_paths),
batch_size=16, batch_size=16,
num_workers=self.num_workers, num_workers=self.num_workers,
collate_fn=self._collate_fn, collate_fn=self._collate_fn,
......
...@@ -70,27 +70,23 @@ def write_video(filename, video_array, fps: Union[int, float], video_codec="libx ...@@ -70,27 +70,23 @@ def write_video(filename, video_array, fps: Union[int, float], video_codec="libx
if isinstance(fps, float): if isinstance(fps, float):
fps = np.round(fps) fps = np.round(fps)
container = av.open(filename, mode="w") with av.open(filename, mode="w") as container:
stream = container.add_stream(video_codec, rate=fps)
stream = container.add_stream(video_codec, rate=fps) stream.width = video_array.shape[2]
stream.width = video_array.shape[2] stream.height = video_array.shape[1]
stream.height = video_array.shape[1] stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24" stream.options = options or {}
stream.options = options or {}
for img in video_array:
for img in video_array: frame = av.VideoFrame.from_ndarray(img, format="rgb24")
frame = av.VideoFrame.from_ndarray(img, format="rgb24") frame.pict_type = "NONE"
frame.pict_type = "NONE" for packet in stream.encode(frame):
for packet in stream.encode(frame): container.mux(packet)
# Flush stream
for packet in stream.encode():
container.mux(packet) container.mux(packet)
# Flush stream
for packet in stream.encode():
container.mux(packet)
# Close the file
container.close()
def _read_from_stream( def _read_from_stream(
container, start_offset, end_offset, pts_unit, stream, stream_name container, start_offset, end_offset, pts_unit, stream, stream_name
...@@ -234,37 +230,35 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"): ...@@ -234,37 +230,35 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
audio_frames = [] audio_frames = []
try: try:
container = av.open(filename, metadata_errors="ignore") with av.open(filename, metadata_errors="ignore") as container:
if container.streams.video:
video_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.video[0],
{"video": 0},
)
video_fps = container.streams.video[0].average_rate
# guard against potentially corrupted files
if video_fps is not None:
info["video_fps"] = float(video_fps)
if container.streams.audio:
audio_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.audio[0],
{"audio": 0},
)
info["audio_fps"] = container.streams.audio[0].rate
except av.AVError: except av.AVError:
# TODO raise a warning? # TODO raise a warning?
pass pass
else:
if container.streams.video:
video_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.video[0],
{"video": 0},
)
video_fps = container.streams.video[0].average_rate
# guard against potentially corrupted files
if video_fps is not None:
info["video_fps"] = float(video_fps)
if container.streams.audio:
audio_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.audio[0],
{"audio": 0},
)
info["audio_fps"] = container.streams.audio[0].rate
container.close()
vframes = [frame.to_rgb().to_ndarray() for frame in video_frames] vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
aframes = [frame.to_ndarray() for frame in audio_frames] aframes = [frame.to_ndarray() for frame in audio_frames]
...@@ -293,6 +287,14 @@ def _can_read_timestamps_from_packets(container): ...@@ -293,6 +287,14 @@ def _can_read_timestamps_from_packets(container):
return False return False
def _decode_video_timestamps(container):
    """
    Gather the non-``None`` ``pts`` values for video stream 0 of ``container``.

    When ``_can_read_timestamps_from_packets(container)`` is true, the values
    are read from ``container.demux(video=0)`` (the fast path); otherwise
    they are read from ``container.decode(video=0)``.

    Returns a list of ``pts`` values in iteration order; items whose ``pts``
    is ``None`` are skipped.
    """
    if _can_read_timestamps_from_packets(container):
        # fast path
        items = container.demux(video=0)
    else:
        items = container.decode(video=0)
    return [item.pts for item in items if item.pts is not None]
def read_video_timestamps(filename, pts_unit="pts"): def read_video_timestamps(filename, pts_unit="pts"):
""" """
List the video frames timestamps. List the video frames timestamps.
...@@ -326,26 +328,18 @@ def read_video_timestamps(filename, pts_unit="pts"): ...@@ -326,26 +328,18 @@ def read_video_timestamps(filename, pts_unit="pts"):
pts = [] pts = []
try: try:
container = av.open(filename, metadata_errors="ignore") with av.open(filename, metadata_errors="ignore") as container:
if container.streams.video:
video_stream = container.streams.video[0]
video_time_base = video_stream.time_base
try:
pts = _decode_video_timestamps(container)
except av.AVError:
warnings.warn(f"Failed decoding frames for file {filename}")
video_fps = float(video_stream.average_rate)
except av.AVError: except av.AVError:
# TODO add a warning # TODO add a warning
pass pass
else:
if container.streams.video:
video_stream = container.streams.video[0]
video_time_base = video_stream.time_base
try:
if _can_read_timestamps_from_packets(container):
# fast path
pts = [x.pts for x in container.demux(video=0) if x.pts is not None]
else:
pts = [
x.pts for x in container.decode(video=0) if x.pts is not None
]
except av.AVError:
warnings.warn(f"Failed decoding frames for file {filename}")
video_fps = float(video_stream.average_rate)
container.close()
pts.sort() pts.sort()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment