Commit 757ecfb0 authored by Oana Florescu's avatar Oana Florescu Committed by Francisco Massa
Browse files

VideoClips windows fixes (#1661)

* remove windows skips from video_utils tests, now that they pass

* replace lambda in videoclips in order to be pickled on windows and update tests
parent 17ea1482
...@@ -59,10 +59,9 @@ class Tester(unittest.TestCase): ...@@ -59,10 +59,9 @@ class Tester(unittest.TestCase):
self.assertTrue(r.equal(expected)) self.assertTrue(r.equal(expected))
@unittest.skipIf(not io.video._av_available(), "this test requires av") @unittest.skipIf(not io.video._av_available(), "this test requires av")
@unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
def test_video_clips(self): def test_video_clips(self):
with get_list_of_videos(num_videos=3) as video_list: with get_list_of_videos(num_videos=3) as video_list:
video_clips = VideoClips(video_list, 5, 5) video_clips = VideoClips(video_list, 5, 5, num_workers=2)
self.assertEqual(video_clips.num_clips(), 1 + 2 + 3) self.assertEqual(video_clips.num_clips(), 1 + 2 + 3)
for i, (v_idx, c_idx) in enumerate([(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]): for i, (v_idx, c_idx) in enumerate([(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]):
video_idx, clip_idx = video_clips.get_clip_location(i) video_idx, clip_idx = video_clips.get_clip_location(i)
...@@ -84,12 +83,11 @@ class Tester(unittest.TestCase): ...@@ -84,12 +83,11 @@ class Tester(unittest.TestCase):
self.assertEqual(clip_idx, c_idx) self.assertEqual(clip_idx, c_idx)
@unittest.skipIf(not io.video._av_available(), "this test requires av") @unittest.skipIf(not io.video._av_available(), "this test requires av")
@unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
def test_video_clips_custom_fps(self): def test_video_clips_custom_fps(self):
with get_list_of_videos(num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) as video_list: with get_list_of_videos(num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) as video_list:
num_frames = 4 num_frames = 4
for fps in [1, 3, 4, 10]: for fps in [1, 3, 4, 10]:
video_clips = VideoClips(video_list, num_frames, num_frames, fps) video_clips = VideoClips(video_list, num_frames, num_frames, fps, num_workers=2)
for i in range(video_clips.num_clips()): for i in range(video_clips.num_clips()):
video, audio, info, video_idx = video_clips.get_clip(i) video, audio, info, video_idx = video_clips.get_clip(i)
self.assertEqual(video.shape[0], num_frames) self.assertEqual(video.shape[0], num_frames)
......
...@@ -104,6 +104,9 @@ class VideoClips(object): ...@@ -104,6 +104,9 @@ class VideoClips(object):
self._init_from_metadata(_precomputed_metadata) self._init_from_metadata(_precomputed_metadata)
self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate) self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate)
def _collate_fn(self, x):
return x
def _compute_frame_pts(self): def _compute_frame_pts(self):
self.video_pts = [] self.video_pts = []
self.video_fps = [] self.video_fps = []
...@@ -115,7 +118,7 @@ class VideoClips(object): ...@@ -115,7 +118,7 @@ class VideoClips(object):
_DummyDataset(self.video_paths), _DummyDataset(self.video_paths),
batch_size=16, batch_size=16,
num_workers=self.num_workers, num_workers=self.num_workers,
collate_fn=lambda x: x) collate_fn=self._collate_fn)
with tqdm(total=len(dl)) as pbar: with tqdm(total=len(dl)) as pbar:
for batch in dl: for batch in dl:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment