Commit 49b01e3a authored by Zhicheng Yan, committed by Francisco Massa

add metadata to video dataset classes. bug fix. more robustness (#1376)

* add metadata to video dataset classes. bug fix. more robustness

* query video backend within VideoClips class

* Fix tests

* Fix lint
parent d02db177
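For context, the metadata these dataset classes now expose lets the expensive clip-indexing pass be computed once and reused. A minimal sketch of that workflow using Kinetics400; the root path and clip parameters below are illustrative:

import torch
from torchvision.datasets import Kinetics400

# First construction scans every video to index clips (slow for large datasets).
dataset = Kinetics400("/data/kinetics/train", frames_per_clip=16,
                      step_between_clips=16, num_workers=8)

# The new `metadata` property exposes the precomputed clip index,
# which can be cached to disk ...
torch.save(dataset.metadata, "kinetics_train_metadata.pt")

# ... and passed back on later runs to skip the scan entirely.
cached = torch.load("kinetics_train_metadata.pt")
dataset = Kinetics400("/data/kinetics/train", frames_per_clip=16,
                      step_between_clips=16, _precomputed_metadata=cached)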
@@ -6,7 +6,6 @@ import unittest
from torchvision import io
from torchvision.datasets.video_utils import VideoClips, unfold
from torchvision import get_video_backend
from common_utils import get_tmp_dir
@@ -62,23 +61,22 @@ class Tester(unittest.TestCase):
@unittest.skipIf(not io.video._av_available(), "this test requires av")
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
def test_video_clips(self):
_backend = get_video_backend()
with get_list_of_videos(num_videos=3) as video_list:
video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
video_clips = VideoClips(video_list, 5, 5)
self.assertEqual(video_clips.num_clips(), 1 + 2 + 3)
for i, (v_idx, c_idx) in enumerate([(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]):
video_idx, clip_idx = video_clips.get_clip_location(i)
self.assertEqual(video_idx, v_idx)
self.assertEqual(clip_idx, c_idx)
video_clips = VideoClips(video_list, 6, 6, _backend=_backend)
video_clips = VideoClips(video_list, 6, 6)
self.assertEqual(video_clips.num_clips(), 0 + 1 + 2)
for i, (v_idx, c_idx) in enumerate([(1, 0), (2, 0), (2, 1)]):
video_idx, clip_idx = video_clips.get_clip_location(i)
self.assertEqual(video_idx, v_idx)
self.assertEqual(clip_idx, c_idx)
video_clips = VideoClips(video_list, 6, 1, _backend=_backend)
video_clips = VideoClips(video_list, 6, 1)
self.assertEqual(video_clips.num_clips(), 0 + (10 - 6 + 1) + (15 - 6 + 1))
for i, v_idx, c_idx in [(0, 1, 0), (4, 1, 4), (5, 2, 0), (6, 2, 1)]:
video_idx, clip_idx = video_clips.get_clip_location(i)
@@ -87,9 +85,8 @@ class Tester(unittest.TestCase):
@unittest.skip("Moved to reference scripts for now")
def test_video_sampler(self):
_backend = get_video_backend()
with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
video_clips = VideoClips(video_list, 5, 5)
sampler = RandomClipSampler(video_clips, 3) # noqa: F821
self.assertEqual(len(sampler), 3 * 3)
indices = torch.tensor(list(iter(sampler)))
@@ -100,9 +97,8 @@ class Tester(unittest.TestCase):
@unittest.skip("Moved to reference scripts for now")
def test_video_sampler_unequal(self):
_backend = get_video_backend()
with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
video_clips = VideoClips(video_list, 5, 5)
sampler = RandomClipSampler(video_clips, 3) # noqa: F821
self.assertEqual(len(sampler), 2 + 3 + 3)
indices = list(iter(sampler))
@@ -120,11 +116,10 @@ class Tester(unittest.TestCase):
@unittest.skipIf(not io.video._av_available(), "this test requires av")
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
def test_video_clips_custom_fps(self):
_backend = get_video_backend()
with get_list_of_videos(num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) as video_list:
num_frames = 4
for fps in [1, 3, 4, 10]:
video_clips = VideoClips(video_list, num_frames, num_frames, fps, _backend=_backend)
video_clips = VideoClips(video_list, num_frames, num_frames, fps)
for i in range(video_clips.num_clips()):
video, audio, info, video_idx = video_clips.get_clip(i)
self.assertEqual(video.shape[0], num_frames)
......
import warnings
from torchvision import models
from torchvision import datasets
from torchvision import ops
@@ -57,6 +59,9 @@ def set_video_backend(backend):
raise ValueError(
"Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend
)
if backend == "video_reader" and not io._HAS_VIDEO_OPT:
warnings.warn("video_reader video backend is not available")
else:
_video_backend = backend
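A short sketch of how the new guard behaves when the native video_reader extension is unavailable (the fallback assumed here is the default "pyav" backend):

import warnings
import torchvision

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torchvision.set_video_backend("video_reader")

# If io._HAS_VIDEO_OPT is False, a warning is emitted and the backend is
# left unchanged rather than being switched to a non-functional reader.
print([str(w.message) for w in caught])
print(torchvision.get_video_backend())  # still "pyav" in that case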
......
import glob
import os
from .video_utils import VideoClips
from .utils import list_dir
from .folder import make_dataset
from .video_utils import VideoClips
from .vision import VisionDataset
@@ -51,7 +51,8 @@ class HMDB51(VisionDataset):
def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
frame_rate=None, fold=1, train=True, transform=None,
_precomputed_metadata=None):
_precomputed_metadata=None, num_workers=1, _video_width=0,
_video_height=0, _video_min_dimension=0, _audio_samples=0):
super(HMDB51, self).__init__(root)
if not 1 <= fold <= 3:
raise ValueError("fold should be between 1 and 3, got {}".format(fold))
@@ -71,11 +72,21 @@ class HMDB51(VisionDataset):
step_between_clips,
frame_rate,
_precomputed_metadata,
num_workers=num_workers,
_video_width=_video_width,
_video_height=_video_height,
_video_min_dimension=_video_min_dimension,
_audio_samples=_audio_samples,
)
self.video_clips_metadata = video_clips.metadata
self.indices = self._select_fold(video_list, annotation_path, fold, train)
self.video_clips = video_clips.subset(self.indices)
self.transform = transform
@property
def metadata(self):
return self.video_clips_metadata
def _select_fold(self, video_list, annotation_path, fold, train):
target_tag = 1 if train else 2
name = "*test_split{}.txt".format(fold)
......
from .video_utils import VideoClips
from .utils import list_dir
from .folder import make_dataset
from .video_utils import VideoClips
from .vision import VisionDataset
@@ -37,7 +37,9 @@ class Kinetics400(VisionDataset):
"""
def __init__(self, root, frames_per_clip, step_between_clips=1, frame_rate=None,
extensions=('avi',), transform=None, _precomputed_metadata=None):
extensions=('avi',), transform=None, _precomputed_metadata=None,
num_workers=1, _video_width=0, _video_height=0,
_video_min_dimension=0, _audio_samples=0):
super(Kinetics400, self).__init__(root)
extensions = ('avi',)
@@ -52,9 +54,18 @@ class Kinetics400(VisionDataset):
step_between_clips,
frame_rate,
_precomputed_metadata,
num_workers=num_workers,
_video_width=_video_width,
_video_height=_video_height,
_video_min_dimension=_video_min_dimension,
_audio_samples=_audio_samples,
)
self.transform = transform
@property
def metadata(self):
return self.video_clips.metadata
def __len__(self):
return self.video_clips.num_clips()
......
import glob
import os
from .video_utils import VideoClips
from .utils import list_dir
from .folder import make_dataset
from .video_utils import VideoClips
from .vision import VisionDataset
@@ -44,7 +44,8 @@ class UCF101(VisionDataset):
def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
frame_rate=None, fold=1, train=True, transform=None,
_precomputed_metadata=None):
_precomputed_metadata=None, num_workers=1, _video_width=0,
_video_height=0, _video_min_dimension=0, _audio_samples=0):
super(UCF101, self).__init__(root)
if not 1 <= fold <= 3:
raise ValueError("fold should be between 1 and 3, got {}".format(fold))
@@ -64,11 +65,21 @@ class UCF101(VisionDataset):
step_between_clips,
frame_rate,
_precomputed_metadata,
num_workers=num_workers,
_video_width=_video_width,
_video_height=_video_height,
_video_min_dimension=_video_min_dimension,
_audio_samples=_audio_samples,
)
self.video_clips_metadata = video_clips.metadata
self.indices = self._select_fold(video_list, annotation_path, fold, train)
self.video_clips = video_clips.subset(self.indices)
self.transform = transform
@property
def metadata(self):
return self.video_clips_metadata
def _select_fold(self, video_list, annotation_path, fold, train):
name = "train" if train else "test"
name = "{}list{:02d}.txt".format(name, fold)
......
@@ -68,10 +68,18 @@ class VideoClips(object):
0 means that the data will be loaded in the main process. (default: 0)
"""
def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1,
frame_rate=None, _precomputed_metadata=None, num_workers=0, _backend="pyav"):
frame_rate=None, _precomputed_metadata=None, num_workers=0,
_video_width=0, _video_height=0, _video_min_dimension=0,
_audio_samples=0):
from torchvision import get_video_backend
self.video_paths = video_paths
self.num_workers = num_workers
self._backend = _backend
self._backend = get_video_backend()
self._video_width = _video_width
self._video_height = _video_height
self._video_min_dimension = _video_min_dimension
self._audio_samples = _audio_samples
if _precomputed_metadata is None:
self._compute_frame_pts()
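With the backend now queried inside the constructor, callers pick it globally via set_video_backend instead of threading a `_backend` argument through every call site. A brief sketch; the file paths are placeholders:

import torchvision
from torchvision.datasets.video_utils import VideoClips

torchvision.set_video_backend("pyav")  # global choice, no per-instance argument

video_list = ["/data/videos/a.mp4", "/data/videos/b.mp4"]  # placeholder paths
clips = VideoClips(video_list, clip_length_in_frames=5, frames_between_clips=5)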
@@ -145,6 +153,7 @@ class VideoClips(object):
_metadata.update({"video_fps": self.video_fps})
else:
_metadata.update({"info": self.info})
return _metadata
def subset(self, indices):
video_paths = [self.video_paths[i] for i in indices]
@@ -162,7 +171,11 @@ class VideoClips(object):
else:
metadata.update({"info": info})
return type(self)(video_paths, self.num_frames, self.step, self.frame_rate,
_precomputed_metadata=metadata)
_precomputed_metadata=metadata, num_workers=self.num_workers,
_video_width=self._video_width,
_video_height=self._video_height,
_video_min_dimension=self._video_min_dimension,
_audio_samples=self._audio_samples)
@staticmethod
def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
@@ -206,9 +219,15 @@ class VideoClips(object):
self.resampling_idxs.append(idxs)
else:
for video_pts, info in zip(self.video_pts, self.info):
clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, info["video_fps"], frame_rate)
if "video_fps" in info:
clips, idxs = self.compute_clips_for_video(
video_pts, num_frames, step, info["video_fps"], frame_rate)
self.clips.append(clips)
self.resampling_idxs.append(idxs)
else:
# properly handle the cases where video decoding fails
self.clips.append(torch.zeros(0, num_frames, dtype=torch.int64))
self.resampling_idxs.append(torch.zeros(0, dtype=torch.int64))
clip_lengths = torch.as_tensor([len(v) for v in self.clips])
self.cumulative_sizes = clip_lengths.cumsum(0).tolist()
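A small sketch of why the empty placeholder keeps clip indexing consistent when a video fails to decode; the tensors below are made-up stand-ins for real per-clip frame timestamps:

import torch

num_frames = 6
clips = [
    torch.zeros(0, num_frames, dtype=torch.int64),        # video 0: decode failed, no clips
    torch.arange(5 * num_frames).reshape(5, num_frames),  # video 1: 5 clips of 6 frames
]
clip_lengths = torch.as_tensor([len(v) for v in clips])
cumulative_sizes = clip_lengths.cumsum(0).tolist()        # [0, 5]

# get_clip_location bisects over cumulative_sizes, so every clip index 0..4
# resolves to video 1; the broken video simply contributes zero clips instead
# of raising a KeyError on the missing "video_fps" entry.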
@@ -296,8 +315,12 @@ class VideoClips(object):
)
video, audio, info = _read_video_from_file(
video_path,
video_width=self._video_width,
video_height=self._video_height,
video_min_dimension=self._video_min_dimension,
video_pts_range=(video_start_pts, video_end_pts),
video_timebase=info["video_timebase"],
audio_samples=self._audio_samples,
audio_pts_range=(audio_start_pts, audio_end_pts),
audio_timebase=audio_timebase,
)
......
from .video import write_video, read_video, read_video_timestamps
from ._video_opt import _read_video_from_file, _read_video_timestamps_from_file
from ._video_opt import _read_video_from_file, _read_video_timestamps_from_file, _HAS_VIDEO_OPT
__all__ = [
'write_video', 'read_video', 'read_video_timestamps',
'_read_video_from_file', '_read_video_timestamps_from_file',
'_read_video_from_file', '_read_video_timestamps_from_file', '_HAS_VIDEO_OPT',
]