"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "eb5267f3778e30905dc77352195986d928156ae5"
Unverified Commit e3f1a822 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Improve test_videoapi (#5497)

parent c29a20ab
...@@ -52,139 +52,131 @@ test_videos = { ...@@ -52,139 +52,131 @@ test_videos = {
@pytest.mark.skipif(_HAS_VIDEO_OPT is False, reason="Didn't compile with ffmpeg") @pytest.mark.skipif(_HAS_VIDEO_OPT is False, reason="Didn't compile with ffmpeg")
class TestVideoApi: class TestVideoApi:
@pytest.mark.skipif(av is None, reason="PyAV unavailable") @pytest.mark.skipif(av is None, reason="PyAV unavailable")
def test_frame_reading(self): @pytest.mark.parametrize("test_video", test_videos.keys())
for test_video, config in test_videos.items(): def test_frame_reading(self, test_video):
full_path = os.path.join(VIDEO_DIR, test_video) full_path = os.path.join(VIDEO_DIR, test_video)
with av.open(full_path) as av_reader:
with av.open(full_path) as av_reader: if av_reader.streams.video:
is_video = True if av_reader.streams.video else False av_frames, vr_frames = [], []
av_pts, vr_pts = [], []
if is_video: # get av frames
av_frames, vr_frames = [], [] for av_frame in av_reader.decode(av_reader.streams.video[0]):
av_pts, vr_pts = [], [] av_frames.append(torch.tensor(av_frame.to_rgb().to_ndarray()).permute(2, 0, 1))
# get av frames av_pts.append(av_frame.pts * av_frame.time_base)
for av_frame in av_reader.decode(av_reader.streams.video[0]):
av_frames.append(torch.tensor(av_frame.to_rgb().to_ndarray()).permute(2, 0, 1)) # get vr frames
av_pts.append(av_frame.pts * av_frame.time_base) video_reader = VideoReader(full_path, "video")
for vr_frame in video_reader:
# get vr frames vr_frames.append(vr_frame["data"])
video_reader = VideoReader(full_path, "video") vr_pts.append(vr_frame["pts"])
for vr_frame in video_reader:
vr_frames.append(vr_frame["data"]) # same number of frames
vr_pts.append(vr_frame["pts"]) assert len(vr_frames) == len(av_frames)
assert len(vr_pts) == len(av_pts)
# same number of frames
assert len(vr_frames) == len(av_frames) # compare the frames and ptss
assert len(vr_pts) == len(av_pts) for i in range(len(vr_frames)):
assert float(av_pts[i]) == approx(vr_pts[i], abs=0.1)
# compare the frames and ptss mean_delta = torch.mean(torch.abs(av_frames[i].float() - vr_frames[i].float()))
for i in range(len(vr_frames)): # on average the difference is very small and caused
assert float(av_pts[i]) == approx(vr_pts[i], abs=0.1) # by decoding (around 1%)
mean_delta = torch.mean(torch.abs(av_frames[i].float() - vr_frames[i].float())) # TODO: asses empirically how to set this? atm it's 1%
# on average the difference is very small and caused # averaged over all frames
# by decoding (around 1%) assert mean_delta.item() < 2.55
# TODO: asses empirically how to set this? atm it's 1%
# averaged over all frames del vr_frames, av_frames, vr_pts, av_pts
assert mean_delta.item() < 2.55
# test audio reading compared to PYAV
del vr_frames, av_frames, vr_pts, av_pts with av.open(full_path) as av_reader:
if av_reader.streams.audio:
# test audio reading compared to PYAV av_frames, vr_frames = [], []
with av.open(full_path) as av_reader: av_pts, vr_pts = [], []
is_audio = True if av_reader.streams.audio else False # get av frames
for av_frame in av_reader.decode(av_reader.streams.audio[0]):
if is_audio: av_frames.append(torch.tensor(av_frame.to_ndarray()).permute(1, 0))
av_frames, vr_frames = [], [] av_pts.append(av_frame.pts * av_frame.time_base)
av_pts, vr_pts = [], [] av_reader.close()
# get av frames
for av_frame in av_reader.decode(av_reader.streams.audio[0]): # get vr frames
av_frames.append(torch.tensor(av_frame.to_ndarray()).permute(1, 0)) video_reader = VideoReader(full_path, "audio")
av_pts.append(av_frame.pts * av_frame.time_base) for vr_frame in video_reader:
av_reader.close() vr_frames.append(vr_frame["data"])
vr_pts.append(vr_frame["pts"])
# get vr frames
video_reader = VideoReader(full_path, "audio") # same number of frames
for vr_frame in video_reader: assert len(vr_frames) == len(av_frames)
vr_frames.append(vr_frame["data"]) assert len(vr_pts) == len(av_pts)
vr_pts.append(vr_frame["pts"])
# compare the frames and ptss
# same number of frames for i in range(len(vr_frames)):
assert len(vr_frames) == len(av_frames) assert float(av_pts[i]) == approx(vr_pts[i], abs=0.1)
assert len(vr_pts) == len(av_pts) max_delta = torch.max(torch.abs(av_frames[i].float() - vr_frames[i].float()))
# we assure that there is never more than 1% difference in signal
# compare the frames and ptss assert max_delta.item() < 0.001
for i in range(len(vr_frames)):
assert float(av_pts[i]) == approx(vr_pts[i], abs=0.1) @pytest.mark.parametrize("test_video,config", test_videos.items())
max_delta = torch.max(torch.abs(av_frames[i].float() - vr_frames[i].float())) def test_metadata(self, test_video, config):
# we assure that there is never more than 1% difference in signal
assert max_delta.item() < 0.001
def test_metadata(self):
""" """
Test that the metadata returned via pyav corresponds to the one returned Test that the metadata returned via pyav corresponds to the one returned
by the new video decoder API by the new video decoder API
""" """
for test_video, config in test_videos.items(): full_path = os.path.join(VIDEO_DIR, test_video)
full_path = os.path.join(VIDEO_DIR, test_video) reader = VideoReader(full_path, "video")
reader = VideoReader(full_path, "video") reader_md = reader.get_metadata()
reader_md = reader.get_metadata() assert config.video_fps == approx(reader_md["video"]["fps"][0], abs=0.0001)
assert config.video_fps == approx(reader_md["video"]["fps"][0], abs=0.0001) assert config.duration == approx(reader_md["video"]["duration"][0], abs=0.5)
assert config.duration == approx(reader_md["video"]["duration"][0], abs=0.5)
@pytest.mark.parametrize("test_video", test_videos.keys())
def test_seek_start(self): def test_seek_start(self, test_video):
for test_video, config in test_videos.items(): full_path = os.path.join(VIDEO_DIR, test_video)
full_path = os.path.join(VIDEO_DIR, test_video) video_reader = VideoReader(full_path, "video")
num_frames = 0
video_reader = VideoReader(full_path, "video") for _ in video_reader:
num_frames += 1
# now seek the container to 0 and do it again
# It's often that starting seek can be inprecise
# this way and it doesn't start at 0
video_reader.seek(0)
start_num_frames = 0
for _ in video_reader:
start_num_frames += 1
assert start_num_frames == num_frames
# now seek the container to < 0 to check for unexpected behaviour
video_reader.seek(-1)
start_num_frames = 0
for _ in video_reader:
start_num_frames += 1
assert start_num_frames == num_frames
@pytest.mark.parametrize("test_video", test_videos.keys())
def test_accurateseek_middle(self, test_video):
full_path = os.path.join(VIDEO_DIR, test_video)
stream = "video"
video_reader = VideoReader(full_path, stream)
md = video_reader.get_metadata()
duration = md[stream]["duration"][0]
if duration is not None:
num_frames = 0 num_frames = 0
for frame in video_reader: for _ in video_reader:
num_frames += 1 num_frames += 1
# now seek the container to 0 and do it again video_reader.seek(duration / 2)
# It's often that starting seek can be inprecise middle_num_frames = 0
# this way and it doesn't start at 0 for _ in video_reader:
video_reader.seek(0) middle_num_frames += 1
start_num_frames = 0
for frame in video_reader:
start_num_frames += 1
assert start_num_frames == num_frames
# now seek the container to < 0 to check for unexpected behaviour
video_reader.seek(-1)
start_num_frames = 0
for frame in video_reader:
start_num_frames += 1
assert start_num_frames == num_frames
def test_accurateseek_middle(self):
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
stream = "video" assert middle_num_frames < num_frames
video_reader = VideoReader(full_path, stream) assert middle_num_frames == approx(num_frames // 2, abs=1)
md = video_reader.get_metadata()
duration = md[stream]["duration"][0]
if duration is not None:
num_frames = 0 video_reader.seek(duration / 2)
for frame in video_reader: frame = next(video_reader)
num_frames += 1 lb = duration / 2 - 1 / md[stream]["fps"][0]
ub = duration / 2 + 1 / md[stream]["fps"][0]
video_reader.seek(duration / 2) assert (lb <= frame["pts"]) and (ub >= frame["pts"])
middle_num_frames = 0
for frame in video_reader:
middle_num_frames += 1
assert middle_num_frames < num_frames
assert middle_num_frames == approx(num_frames // 2, abs=1)
video_reader.seek(duration / 2)
frame = next(video_reader)
lb = duration / 2 - 1 / md[stream]["fps"][0]
ub = duration / 2 + 1 / md[stream]["fps"][0]
assert (lb <= frame["pts"]) and (ub >= frame["pts"])
def test_fate_suite(self): def test_fate_suite(self):
# TODO: remove the try-except statement once the connectivity issues are resolved # TODO: remove the try-except statement once the connectivity issues are resolved
...@@ -199,41 +191,41 @@ class TestVideoApi: ...@@ -199,41 +191,41 @@ class TestVideoApi:
os.remove(video_path) os.remove(video_path)
@pytest.mark.skipif(av is None, reason="PyAV unavailable") @pytest.mark.skipif(av is None, reason="PyAV unavailable")
def test_keyframe_reading(self): @pytest.mark.parametrize("test_video,config", test_videos.items())
for test_video, config in test_videos.items(): def test_keyframe_reading(self, test_video, config):
full_path = os.path.join(VIDEO_DIR, test_video) full_path = os.path.join(VIDEO_DIR, test_video)
av_reader = av.open(full_path)
# reduce streams to only keyframes
av_stream = av_reader.streams.video[0]
av_stream.codec_context.skip_frame = "NONKEY"
av_keyframes = [] av_reader = av.open(full_path)
vr_keyframes = [] # reduce streams to only keyframes
if av_reader.streams.video: av_stream = av_reader.streams.video[0]
av_stream.codec_context.skip_frame = "NONKEY"
# get all keyframes using pyav. Then, seek randomly into video reader av_keyframes = []
# and assert that all the returned values are in AV_KEYFRAMES vr_keyframes = []
if av_reader.streams.video:
for av_frame in av_reader.decode(av_stream): # get all keyframes using pyav. Then, seek randomly into video reader
av_keyframes.append(float(av_frame.pts * av_frame.time_base)) # and assert that all the returned values are in AV_KEYFRAMES
if len(av_keyframes) > 1: for av_frame in av_reader.decode(av_stream):
video_reader = VideoReader(full_path, "video") av_keyframes.append(float(av_frame.pts * av_frame.time_base))
for i in range(1, len(av_keyframes)):
seek_val = (av_keyframes[i] + av_keyframes[i - 1]) / 2
data = next(video_reader.seek(seek_val, True))
vr_keyframes.append(data["pts"])
data = next(video_reader.seek(config.duration, True)) if len(av_keyframes) > 1:
video_reader = VideoReader(full_path, "video")
for i in range(1, len(av_keyframes)):
seek_val = (av_keyframes[i] + av_keyframes[i - 1]) / 2
data = next(video_reader.seek(seek_val, True))
vr_keyframes.append(data["pts"]) vr_keyframes.append(data["pts"])
assert len(av_keyframes) == len(vr_keyframes) data = next(video_reader.seek(config.duration, True))
# NOTE: this video gets different keyframe with different vr_keyframes.append(data["pts"])
# loaders (0.333 pyav, 0.666 for us)
if test_video != "TrumanShow_wave_f_nm_np1_fr_med_26.avi": assert len(av_keyframes) == len(vr_keyframes)
for i in range(len(av_keyframes)): # NOTE: this video gets different keyframe with different
assert av_keyframes[i] == approx(vr_keyframes[i], rel=0.001) # loaders (0.333 pyav, 0.666 for us)
if test_video != "TrumanShow_wave_f_nm_np1_fr_med_26.avi":
for i in range(len(av_keyframes)):
assert av_keyframes[i] == approx(vr_keyframes[i], rel=0.001)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment