Unverified Commit 7d509c5d authored by Francisco Massa, committed by GitHub

Unify video metadata in VideoClips (#1527)

* Unify video metadata in VideoClips

* Bugfix

* Make tests a bit more robust
parent c226bb95
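
Context for the change: before this commit, VideoClips kept backend-specific metadata (a video_fps list under the pyav backend, a raw per-video info dict under the video_reader backend), and get_clip returned whichever shape the active backend produced. After it, both backends expose the same info dict. A minimal sketch of the resulting contract, assuming video_paths is your own list of video files:

from torchvision.datasets.video_utils import VideoClips

video_clips = VideoClips(video_paths, 5, 5)
video, audio, info, video_idx = video_clips.get_clip(0)
# info now has the same shape under either backend:
#   {"video_fps": <fps>}                       when the clip has no audio
#   {"video_fps": <fps>, "audio_fps": <rate>}  when an audio stream exists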
@@ -59,7 +59,7 @@ class Tester(unittest.TestCase):
             self.assertTrue(r.equal(expected))

     @unittest.skipIf(not io.video._av_available(), "this test requires av")
-    @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
+    @unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
     def test_video_clips(self):
         with get_list_of_videos(num_videos=3) as video_list:
             video_clips = VideoClips(video_list, 5, 5)
@@ -84,7 +84,7 @@ class Tester(unittest.TestCase):
             self.assertEqual(clip_idx, c_idx)

     @unittest.skipIf(not io.video._av_available(), "this test requires av")
-    @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
+    @unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
     def test_video_clips_custom_fps(self):
         with get_list_of_videos(num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) as video_list:
             num_frames = 4
@@ -94,6 +94,7 @@ class Tester(unittest.TestCase):
                 video, audio, info, video_idx = video_clips.get_clip(i)
                 self.assertEqual(video.shape[0], num_frames)
                 self.assertEqual(info["video_fps"], fps)
+                self.assertEqual(info, {"video_fps": fps})

     # TODO add tests checking that the content is right

     def test_compute_clips_for_video(self):
...
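
The tests above index clips globally across all videos. A hedged usage sketch of the same pattern (video_list here stands in for the suite's get_list_of_videos helper; any list of paths works):

from torchvision.datasets.video_utils import VideoClips

video_clips = VideoClips(video_list, 5, 5)  # 5-frame clips, stride 5
for i in range(video_clips.num_clips()):
    video, audio, info, video_idx = video_clips.get_clip(i)
    # video is a (T, H, W, C) uint8 tensor with T == 5 frames
    assert video.shape[0] == 5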
+import unittest
+from torchvision import set_video_backend
+import test_datasets_video_utils
+
+set_video_backend('video_reader')
+
+if __name__ == '__main__':
+    suite = unittest.TestLoader().loadTestsFromModule(test_datasets_video_utils)
+    unittest.TextTestRunner(verbosity=1).run(suite)
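
This added module is a standalone runner: it re-executes the whole test_datasets_video_utils suite with the video_reader backend selected, which is exactly the case the unified metadata has to cover. Backend selection is process-global; a small sketch ('some_video.mp4' is a placeholder path):

import torchvision
from torchvision.io import read_video

torchvision.set_video_backend('video_reader')  # C++ decoder; 'pyav' is the default
vframes, aframes, info = read_video('some_video.mp4', pts_unit='sec')
# with this change, info carries the same keys under either backend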
@@ -183,6 +183,7 @@ class Tester(unittest.TestCase):
             self.assertTrue(data.equal(lv))
             self.assertEqual(info["video_fps"], 5)
+            self.assertEqual(info, {"video_fps": 5})

     def test_read_timestamps_pts_unit_sec(self):
         with temp_video(10, 300, 300, 5) as (f_name, data):
...
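
temp_video is the suite's context manager for synthesizing a short clip on disk. Roughly, and only as a sketch of the idea (the real helper passes codec options so that decoded frames round-trip exactly, which random data through a lossy default codec would not):

import contextlib
import os
import tempfile

import torch
from torchvision.io import write_video

@contextlib.contextmanager
def temp_video(num_frames, height, width, fps):
    # write random frames to a temporary mp4 and yield (path, tensor)
    data = torch.randint(0, 256, (num_frames, height, width, 3), dtype=torch.uint8)
    fd, name = tempfile.mkstemp(suffix='.mp4')
    os.close(fd)
    try:
        write_video(name, data, fps)
        yield name, data
    finally:
        os.unlink(name)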
@@ -5,6 +5,7 @@ import torch
 from torchvision.io import (
     _read_video_timestamps_from_file,
     _read_video_from_file,
+    _probe_video_from_file
 )
 from torchvision.io import read_video_timestamps, read_video
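
_probe_video_from_file, imported above, reads container and stream metadata without decoding frames. The keys this diff relies on suggest the shape of its result; an illustration, hedged because this is a private API ('some_video.mp4' is a placeholder path):

from torchvision.io import _probe_video_from_file

info = _probe_video_from_file('some_video.mp4')
fps = info["video_fps"]
# audio keys are present only when the file has an audio stream
if "audio_timebase" in info:
    rate = info["audio_sample_rate"]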
@@ -71,11 +72,11 @@ class VideoClips(object):
                  frame_rate=None, _precomputed_metadata=None, num_workers=0,
                  _video_width=0, _video_height=0, _video_min_dimension=0,
                  _audio_samples=0):
-        from torchvision import get_video_backend
         self.video_paths = video_paths
         self.num_workers = num_workers
-        self._backend = get_video_backend()
+        # these options are not valid for pyav backend
         self._video_width = _video_width
         self._video_height = _video_height
         self._video_min_dimension = _video_min_dimension
@@ -89,30 +90,23 @@ class VideoClips(object):
     def _compute_frame_pts(self):
         self.video_pts = []
-        if self._backend == "pyav":
-            self.video_fps = []
-        else:
-            self.info = []
+        self.video_fps = []

         # strategy: use a DataLoader to parallelize read_video_timestamps
         # so need to create a dummy dataset first
         class DS(object):
-            def __init__(self, x, _backend):
+            def __init__(self, x):
                 self.x = x
-                self._backend = _backend

             def __len__(self):
                 return len(self.x)

             def __getitem__(self, idx):
-                if self._backend == "pyav":
-                    return read_video_timestamps(self.x[idx])
-                else:
-                    return _read_video_timestamps_from_file(self.x[idx])
+                return read_video_timestamps(self.x[idx])

         import torch.utils.data
         dl = torch.utils.data.DataLoader(
-            DS(self.video_paths, self._backend),
+            DS(self.video_paths),
             batch_size=16,
             num_workers=self.num_workers,
             collate_fn=lambda x: x)
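
The DS wrapper above is a generic trick: any per-file function can be fanned out across DataLoader workers by wrapping the list in a map-style dataset and using an identity collate, so each "batch" comes back as a plain list of raw results. Extracted as a standalone sketch (parallel_map and _ListDataset are made-up names, not torchvision API):

import torch.utils.data

class _ListDataset(object):
    # wraps a list so DataLoader workers can apply fn to elements in parallel
    def __init__(self, elements, fn):
        self.elements = elements
        self.fn = fn

    def __len__(self):
        return len(self.elements)

    def __getitem__(self, idx):
        return self.fn(self.elements[idx])

def parallel_map(fn, elements, num_workers=0):
    # collate_fn=lambda x: x skips tensor collation and keeps raw results;
    # note lambdas are not picklable on spawn-based platforms, so there
    # fn and collate_fn would need to be module-level functions
    loader = torch.utils.data.DataLoader(
        _ListDataset(elements, fn),
        batch_size=16,
        num_workers=num_workers,
        collate_fn=lambda x: x)
    out = []
    for batch in loader:  # order is preserved (no shuffling)
        out.extend(batch)
    return out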
@@ -120,56 +114,36 @@ class VideoClips(object):
         with tqdm(total=len(dl)) as pbar:
             for batch in dl:
                 pbar.update(1)
-                if self._backend == "pyav":
-                    clips, fps = list(zip(*batch))
-                    clips = [torch.as_tensor(c) for c in clips]
-                    self.video_pts.extend(clips)
-                    self.video_fps.extend(fps)
-                else:
-                    video_pts, _audio_pts, info = list(zip(*batch))
-                    video_pts = [torch.as_tensor(c) for c in video_pts]
-                    self.video_pts.extend(video_pts)
-                    self.info.extend(info)
+                clips, fps = list(zip(*batch))
+                clips = [torch.as_tensor(c) for c in clips]
+                self.video_pts.extend(clips)
+                self.video_fps.extend(fps)

     def _init_from_metadata(self, metadata):
         self.video_paths = metadata["video_paths"]
         assert len(self.video_paths) == len(metadata["video_pts"])
         self.video_pts = metadata["video_pts"]
-        if self._backend == "pyav":
-            assert len(self.video_paths) == len(metadata["video_fps"])
-            self.video_fps = metadata["video_fps"]
-        else:
-            assert len(self.video_paths) == len(metadata["info"])
-            self.info = metadata["info"]
+        assert len(self.video_paths) == len(metadata["video_fps"])
+        self.video_fps = metadata["video_fps"]
     @property
     def metadata(self):
         _metadata = {
             "video_paths": self.video_paths,
             "video_pts": self.video_pts,
+            "video_fps": self.video_fps
         }
-        if self._backend == "pyav":
-            _metadata.update({"video_fps": self.video_fps})
-        else:
-            _metadata.update({"info": self.info})
         return _metadata
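
Because metadata now has one fixed schema (paths, pts, fps) regardless of backend, it can be serialized once and fed back through the constructor to skip the expensive per-file scan. A sketch, assuming video_paths is your own list and noting that _precomputed_metadata is a private, underscore-prefixed argument that may change:

import torch
from torchvision.datasets.video_utils import VideoClips

# first run: scans every file (slow), then cache the unified metadata
video_clips = VideoClips(video_paths, 16, 8)
torch.save(video_clips.metadata, 'clips_metadata.pt')

# later runs: hand the cached dict back in and skip the scan
metadata = torch.load('clips_metadata.pt')
video_clips = VideoClips(video_paths, 16, 8, _precomputed_metadata=metadata)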
     def subset(self, indices):
         video_paths = [self.video_paths[i] for i in indices]
         video_pts = [self.video_pts[i] for i in indices]
-        if self._backend == "pyav":
-            video_fps = [self.video_fps[i] for i in indices]
-        else:
-            info = [self.info[i] for i in indices]
+        video_fps = [self.video_fps[i] for i in indices]
         metadata = {
             "video_paths": video_paths,
             "video_pts": video_pts,
+            "video_fps": video_fps
         }
-        if self._backend == "pyav":
-            metadata.update({"video_fps": video_fps})
-        else:
-            metadata.update({"info": info})
         return type(self)(video_paths, self.num_frames, self.step, self.frame_rate,
                           _precomputed_metadata=metadata, num_workers=self.num_workers,
                           _video_width=self._video_width,
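
With the backend branch gone, subset always forwards the same three metadata keys, so the returned object rebuilds without rescanning any file. A hypothetical 80/20 video-level split as a usage example:

n = len(video_clips.video_paths)
cut = int(0.8 * n)
train_clips = video_clips.subset(list(range(cut)))   # first 80% of videos
val_clips = video_clips.subset(list(range(cut, n)))  # remaining 20%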
@@ -212,22 +186,10 @@ class VideoClips(object):
         self.frame_rate = frame_rate
         self.clips = []
         self.resampling_idxs = []
-        if self._backend == "pyav":
-            for video_pts, fps in zip(self.video_pts, self.video_fps):
-                clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
-                self.clips.append(clips)
-                self.resampling_idxs.append(idxs)
-        else:
-            for video_pts, info in zip(self.video_pts, self.info):
-                if "video_fps" in info:
-                    clips, idxs = self.compute_clips_for_video(
-                        video_pts, num_frames, step, info["video_fps"], frame_rate)
-                    self.clips.append(clips)
-                    self.resampling_idxs.append(idxs)
-                else:
-                    # properly handle the cases where video decoding fails
-                    self.clips.append(torch.zeros(0, num_frames, dtype=torch.int64))
-                    self.resampling_idxs.append(torch.zeros(0, dtype=torch.int64))
+        for video_pts, fps in zip(self.video_pts, self.video_fps):
+            clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
+            self.clips.append(clips)
+            self.resampling_idxs.append(idxs)
         clip_lengths = torch.as_tensor([len(v) for v in self.clips])
         self.cumulative_sizes = clip_lengths.cumsum(0).tolist()
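
cumulative_sizes is what lets a flat clip index be mapped back to a (video, clip) pair; a video with zero clips (shorter than num_frames) simply repeats the running total. A sketch mirroring VideoClips.get_clip_location:

import bisect

def get_clip_location(cumulative_sizes, idx):
    # find the first video whose running clip total exceeds idx,
    # then offset idx into that video's own clip list
    video_idx = bisect.bisect_right(cumulative_sizes, idx)
    if video_idx == 0:
        clip_idx = idx
    else:
        clip_idx = idx - cumulative_sizes[video_idx - 1]
    return video_idx, clip_idx

# e.g. three videos holding [4, 0, 2] clips give cumulative_sizes == [4, 4, 6];
# flat index 5 is the second clip of the third video
assert get_clip_location([4, 4, 6], 5) == (2, 1)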
@@ -287,12 +249,28 @@ class VideoClips(object):
         video_path = self.video_paths[video_idx]
         clip_pts = self.clips[video_idx][clip_idx]

-        if self._backend == "pyav":
+        from torchvision import get_video_backend
+        backend = get_video_backend()
+
+        if backend == "pyav":
+            # check for invalid options
+            if self._video_width != 0:
+                raise ValueError("pyav backend doesn't support _video_width != 0")
+            if self._video_height != 0:
+                raise ValueError("pyav backend doesn't support _video_height != 0")
+            if self._video_min_dimension != 0:
+                raise ValueError("pyav backend doesn't support _video_min_dimension != 0")
+            if self._audio_samples != 0:
+                raise ValueError("pyav backend doesn't support _audio_samples != 0")
+
+        if backend == "pyav":
             start_pts = clip_pts[0].item()
             end_pts = clip_pts[-1].item()
             video, audio, info = read_video(video_path, start_pts, end_pts)
         else:
-            info = self.info[video_idx]
+            info = _probe_video_from_file(video_path)
+            video_fps = info["video_fps"]
+            audio_fps = None

             video_start_pts = clip_pts[0].item()
             video_end_pts = clip_pts[-1].item()
@@ -313,6 +291,7 @@ class VideoClips(object):
                     info["audio_timebase"],
                     math.ceil,
                 )
+                audio_fps = info["audio_sample_rate"]
             video, audio, info = _read_video_from_file(
                 video_path,
                 video_width=self._video_width,
@@ -324,6 +303,11 @@ class VideoClips(object):
                 audio_pts_range=(audio_start_pts, audio_end_pts),
                 audio_timebase=audio_timebase,
             )
+            info = {"video_fps": video_fps}
+            if audio_fps is not None:
+                info["audio_fps"] = audio_fps

         if self.frame_rate is not None:
             resampling_idx = self.resampling_idxs[video_idx][clip_idx]
             if isinstance(resampling_idx, torch.Tensor):
...
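
When frame_rate is set, get_clip ends by indexing the decoded frames with a precomputed resampling_idx. A sketch of how such indices can be derived from the source fps, modeled on the class's _resample_video_idx helper (treat the exact behavior as an assumption):

import torch

def resample_video_idx(num_frames, original_fps, new_fps):
    # step > 1 drops frames, step < 1 repeats them
    step = float(original_fps) / new_fps
    if step.is_integer():
        # integer stride: a plain slice keeps indexing cheap
        return slice(None, None, int(step))
    idxs = torch.arange(num_frames, dtype=torch.float32) * step
    return idxs.floor().to(torch.int64)

# e.g. a 6 fps source resampled to 3 fps keeps every other frame
assert resample_video_idx(4, 6, 3) == slice(None, None, 2)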
@@ -383,7 +383,7 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'):
         audio_timebase = info['audio_timebase']
         audio_pts_range = get_pts(audio_timebase)

-    return _read_video_from_file(
+    vframes, aframes, info = _read_video_from_file(
         filename,
         read_video_stream=True,
         video_pts_range=video_pts_range,
@@ -392,6 +392,13 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'):
         audio_pts_range=audio_pts_range,
         audio_timebase=audio_timebase,
     )
+    _info = {}
+    if has_video:
+        _info['video_fps'] = info['video_fps']
+    if has_audio:
+        _info['audio_fps'] = info['audio_sample_rate']
+
+    return vframes, aframes, _info

 def _read_video_timestamps(filename, pts_unit='pts'):
...
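
After this change, _read_video filters the raw decoder info down to the same two keys the pyav path returns, so callers never need to know which backend produced the dict. A short usage sketch ('some_video.mp4' is a placeholder path; read_video dispatches to this function under the video_reader backend):

from torchvision.io import read_video

video, audio, info = read_video('some_video.mp4')
fps = info["video_fps"]
audio_fps = info.get("audio_fps")  # None when the file has no audio stream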