Unverified Commit 7478e1a6 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #8 from ycxioooong/master

change video reader to 0-based indexing
parents 27c81690 6d51a941
......@@ -136,9 +136,9 @@ class VideoReader(object):
Returns:
ndarray or None: Return the frame if successful, otherwise None.
"""
pos = self._position + 1
# pos = self._position
if self._cache:
img = self._cache.get(pos)
img = self._cache.get(self._position)
if img is not None:
ret = True
else:
......@@ -146,38 +146,38 @@ class VideoReader(object):
self._set_real_position(self._position)
ret, img = self._vcap.read()
if ret:
self._cache.put(pos, img)
self._cache.put(self._position, img)
else:
ret, img = self._vcap.read()
if ret:
self._position = pos
self._position += 1
return img
def get_frame(self, frame_id):
"""Get frame by index.
Args:
frame_id (int): Index of the expected frame, 1-based.
frame_id (int): Index of the expected frame, 0-based.
Returns:
ndarray or None: Return the frame if successful, otherwise None.
"""
if frame_id <= 0 or frame_id > self._frame_cnt:
raise ValueError('"frame_id" must be between 1 and {}'.format(
if frame_id < 0 or frame_id >= self._frame_cnt:
raise ValueError('"frame_id" must be between 0 and {}'.format(
self._frame_cnt))
if frame_id == self._position + 1:
if frame_id == self._position:
return self.read()
if self._cache:
img = self._cache.get(frame_id)
if img is not None:
self._position = frame_id
self._position = frame_id + 1
return img
self._set_real_position(frame_id - 1)
self._set_real_position(frame_id)
ret, img = self._vcap.read()
if ret:
self._position += 1
if self._cache:
self._cache.put(self._position, img)
self._position += 1
return img
def current_frame(self):
......@@ -189,7 +189,7 @@ class VideoReader(object):
"""
if self._position == 0:
return None
return self._cache.get(self._position)
return self._cache.get(self._position - 1)
def cvt2frames(self,
frame_dir,
......
......@@ -58,11 +58,11 @@ class TestVideo(object):
v = mmcv.VideoReader(self.video_path)
img = v.read()
assert int(round(img.mean())) == 94
img = v.get_frame(64)
img = v.get_frame(63)
assert int(round(img.mean())) == 94
img = v[65]
assert int(round(img.mean())) == 205
img = v[64]
assert int(round(img.mean())) == 205
img = v[63]
assert int(round(img.mean())) == 94
img = v.read()
assert int(round(img.mean())) == 205
......@@ -82,7 +82,7 @@ class TestVideo(object):
for _ in range(10):
v.read()
assert v.position == 10
v.get_frame(100)
v.get_frame(99)
assert v.position == 100
def test_iterator(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment