Unverified commit f03f3140, authored by Francisco Massa, committed by GitHub
Browse files

Add option to load VideoClips from pre-computed metadata (#1154)

parent fecd1385
...@@ -48,9 +48,12 @@ class VideoClips(object): ...@@ -48,9 +48,12 @@ class VideoClips(object):
on the resampled video on the resampled video
""" """
def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1, def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1,
frame_rate=None): frame_rate=None, _precomputed_metadata=None):
self.video_paths = video_paths self.video_paths = video_paths
if _precomputed_metadata is None:
self._compute_frame_pts() self._compute_frame_pts()
else:
self._init_from_metadata(_precomputed_metadata)
self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate) self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate)
def _compute_frame_pts(self): def _compute_frame_pts(self):
...@@ -62,6 +65,23 @@ class VideoClips(object): ...@@ -62,6 +65,23 @@ class VideoClips(object):
self.video_pts.append(torch.as_tensor(clips)) self.video_pts.append(torch.as_tensor(clips))
self.video_fps.append(fps) self.video_fps.append(fps)
def _init_from_metadata(self, metadata):
assert len(self.video_paths) == len(metadata["video_pts"])
assert len(self.video_paths) == len(metadata["video_fps"])
self.video_pts = metadata["video_pts"]
self.video_fps = metadata["video_fps"]
def subset(self, indices):
    """Return a new instance restricted to the videos at ``indices``.

    The selected paths, frame pts, and fps are passed to the constructor
    as precomputed metadata, so the new object skips re-decoding the
    videos; clip parameters (num_frames, step, frame_rate) are reused.
    """
    paths, pts, fps = [], [], []
    for idx in indices:
        paths.append(self.video_paths[idx])
        pts.append(self.video_pts[idx])
        fps.append(self.video_fps[idx])
    return type(self)(
        paths,
        self.num_frames,
        self.step,
        self.frame_rate,
        _precomputed_metadata={"video_pts": pts, "video_fps": fps},
    )
@staticmethod @staticmethod
def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate): def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
if frame_rate is None: if frame_rate is None:
...@@ -155,14 +175,16 @@ class VideoClips(object): ...@@ -155,14 +175,16 @@ class VideoClips(object):
video_idx, clip_idx = self.get_clip_location(idx) video_idx, clip_idx = self.get_clip_location(idx)
video_path = self.video_paths[video_idx] video_path = self.video_paths[video_idx]
clip_pts = self.clips[video_idx][clip_idx] clip_pts = self.clips[video_idx][clip_idx]
video, audio, info = read_video(video_path, clip_pts[0].item(), clip_pts[-1].item()) start_pts = clip_pts[0].item()
end_pts = clip_pts[-1].item()
video, audio, info = read_video(video_path, start_pts, end_pts)
if self.frame_rate is not None: if self.frame_rate is not None:
resampling_idx = self.resampling_idxs[video_idx][clip_idx] resampling_idx = self.resampling_idxs[video_idx][clip_idx]
if isinstance(resampling_idx, torch.Tensor): if isinstance(resampling_idx, torch.Tensor):
resampling_idx = resampling_idx - resampling_idx[0] resampling_idx = resampling_idx - resampling_idx[0]
video = video[resampling_idx] video = video[resampling_idx]
info["video_fps"] = self.frame_rate info["video_fps"] = self.frame_rate
assert len(video) == self.num_frames assert len(video) == self.num_frames, "{} x {}".format(video.shape, self.num_frames)
return video, audio, info, video_idx return video, audio, info, video_idx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.