Unverified Commit 93620a28 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #9 from ycxioooong/master

add slice support and unit test
parents c211ab13 6249f829
...@@ -163,8 +163,8 @@ class VideoReader(object): ...@@ -163,8 +163,8 @@ class VideoReader(object):
ndarray or None: Return the frame if successful, otherwise None. ndarray or None: Return the frame if successful, otherwise None.
""" """
if frame_id < 0 or frame_id >= self._frame_cnt: if frame_id < 0 or frame_id >= self._frame_cnt:
raise ValueError('"frame_id" must be between 0 and {}'.format( raise IndexError('"frame_id" must be between 0 and {}'.format(
self._frame_cnt)) self._frame_cnt - 1))
if frame_id == self._position: if frame_id == self._position:
return self.read() return self.read()
if self._cache: if self._cache:
...@@ -240,7 +240,12 @@ class VideoReader(object): ...@@ -240,7 +240,12 @@ class VideoReader(object):
def __getitem__(self, index): def __getitem__(self, index):
if isinstance(index, slice): if isinstance(index, slice):
raise RuntimeError('slice has not been supported yet') return [self.get_frame(i) for i in range(*index.indices(self.frame_cnt))]
# support negative indexing
if index < 0:
index += self.frame_cnt
if index < 0:
raise IndexError('index out of range')
return self.get_frame(index) return self.get_frame(index)
def __iter__(self): def __iter__(self):
......
...@@ -62,12 +62,45 @@ class TestVideo(object): ...@@ -62,12 +62,45 @@ class TestVideo(object):
assert int(round(img.mean())) == 94 assert int(round(img.mean())) == 94
img = v[64] img = v[64]
assert int(round(img.mean())) == 205 assert int(round(img.mean())) == 205
img = v[-104]
assert int(round(img.mean())) == 205
img = v[63] img = v[63]
assert int(round(img.mean())) == 94 assert int(round(img.mean())) == 94
img = v[-105]
assert int(round(img.mean())) == 94
img = v.read() img = v.read()
assert int(round(img.mean())) == 205 assert int(round(img.mean())) == 205
with pytest.raises(ValueError): with pytest.raises(IndexError):
v.get_frame(self.num_frames + 1) v.get_frame(self.num_frames + 1)
with pytest.raises(IndexError):
v[-self.num_frames - 1]
def test_slice(self):
v = mmcv.VideoReader(self.video_path)
imgs = v[-105:-103]
assert int(round(imgs[0].mean())) == 94
assert int(round(imgs[1].mean())) == 205
assert len(imgs) == 2
imgs = v[63:65]
assert int(round(imgs[0].mean())) == 94
assert int(round(imgs[1].mean())) == 205
assert len(imgs) == 2
imgs = v[64:62:-1]
assert int(round(imgs[0].mean())) == 205
assert int(round(imgs[1].mean())) == 94
assert len(imgs) == 2
imgs = v[:5]
assert len(imgs) == 5
for img in imgs:
assert int(round(img.mean())) == 94
imgs = v[165:]
assert len(imgs) == 3
for img in imgs:
assert int(round(img.mean())) == 0
imgs = v[-3:]
assert len(imgs) == 3
for img in imgs:
assert int(round(img.mean())) == 0
def test_current_frame(self): def test_current_frame(self):
v = mmcv.VideoReader(self.video_path) v = mmcv.VideoReader(self.video_path)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment