Unverified Commit f5843099 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Fixed audio-video synchronisation problem in read_video() when using `pts` as unit (#3791)

* Fixed audio-video synchronisation problem in read_video() when using  as unit

* Addressed review comments

* Added unit test
parent 154283b1
...@@ -1238,6 +1238,45 @@ class TestVideoReader(unittest.TestCase): ...@@ -1238,6 +1238,45 @@ class TestVideoReader(unittest.TestCase):
) )
# FUTURE: check value of video / audio frames # FUTURE: check value of video / audio frames
def test_audio_video_sync(self):
"""Test if audio/video are synchronised with pyav output."""
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
container = av.open(full_path)
if not container.streams.audio:
# Skip if no audio stream
continue
start_pts_val, cutoff = 0, 1
if container.streams.video:
video = container.streams.video[0]
arr = []
for index, frame in enumerate(container.decode(video)):
if index == cutoff:
start_pts_val = frame.pts
if index >= cutoff:
arr.append(frame.to_rgb().to_ndarray())
visual, _, info = io.read_video(full_path, start_pts=start_pts_val, pts_unit='pts')
self.assertAlmostEqual(
config.video_fps, info['video_fps'], delta=0.0001
)
arr = torch.Tensor(arr)
if arr.shape == visual.shape:
self.assertGreaterEqual(
torch.mean(torch.isclose(visual.float(), arr, atol=1e-5).float()), 0.99)
container = av.open(full_path)
if container.streams.audio:
audio = container.streams.audio[0]
arr = []
for index, frame in enumerate(container.decode(audio)):
if index >= cutoff:
arr.append(frame.to_ndarray())
_, audio, _ = io.read_video(full_path, start_pts=start_pts_val, pts_unit='pts')
arr = torch.as_tensor(np.concatenate(arr, axis=1))
if arr.shape == audio.shape:
self.assertGreaterEqual(
torch.mean(torch.isclose(audio.float(), arr).float()), 0.99)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -471,6 +471,14 @@ def _probe_video_from_memory(video_data): ...@@ -471,6 +471,14 @@ def _probe_video_from_memory(video_data):
return info return info
def _convert_to_sec(start_pts, end_pts, pts_unit, time_base):
if pts_unit == 'pts':
start_pts = float(start_pts * time_base)
end_pts = float(end_pts * time_base)
pts_unit = 'sec'
return start_pts, end_pts, pts_unit
def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"): def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
if end_pts is None: if end_pts is None:
end_pts = float("inf") end_pts = float("inf")
...@@ -485,32 +493,43 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"): ...@@ -485,32 +493,43 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
has_video = info.has_video has_video = info.has_video
has_audio = info.has_audio has_audio = info.has_audio
video_pts_range = (0, -1)
video_timebase = default_timebase
audio_pts_range = (0, -1)
audio_timebase = default_timebase
time_base = default_timebase
if has_video:
video_timebase = Fraction(
info.video_timebase.numerator, info.video_timebase.denominator
)
time_base = video_timebase
if has_audio:
audio_timebase = Fraction(
info.audio_timebase.numerator, info.audio_timebase.denominator
)
time_base = time_base if time_base else audio_timebase
# video_timebase is the default time_base
start_pts_sec, end_pts_sec, pts_unit = _convert_to_sec(
start_pts, end_pts, pts_unit, time_base)
def get_pts(time_base): def get_pts(time_base):
start_offset = start_pts start_offset = start_pts_sec
end_offset = end_pts end_offset = end_pts_sec
if pts_unit == "sec": if pts_unit == "sec":
start_offset = int(math.floor(start_pts * (1 / time_base))) start_offset = int(math.floor(start_pts_sec * (1 / time_base)))
if end_offset != float("inf"): if end_offset != float("inf"):
end_offset = int(math.ceil(end_pts * (1 / time_base))) end_offset = int(math.ceil(end_pts_sec * (1 / time_base)))
if end_offset == float("inf"): if end_offset == float("inf"):
end_offset = -1 end_offset = -1
return start_offset, end_offset return start_offset, end_offset
video_pts_range = (0, -1)
video_timebase = default_timebase
if has_video: if has_video:
video_timebase = Fraction(
info.video_timebase.numerator, info.video_timebase.denominator
)
video_pts_range = get_pts(video_timebase) video_pts_range = get_pts(video_timebase)
audio_pts_range = (0, -1)
audio_timebase = default_timebase
if has_audio: if has_audio:
audio_timebase = Fraction(
info.audio_timebase.numerator, info.audio_timebase.denominator
)
audio_pts_range = get_pts(audio_timebase) audio_pts_range = get_pts(audio_timebase)
vframes, aframes, info = _read_video_from_file( vframes, aframes, info = _read_video_from_file(
......
...@@ -278,11 +278,19 @@ def read_video( ...@@ -278,11 +278,19 @@ def read_video(
try: try:
with av.open(filename, metadata_errors="ignore") as container: with av.open(filename, metadata_errors="ignore") as container:
time_base = _video_opt.default_timebase
if container.streams.video:
time_base = container.streams.video[0].time_base
elif container.streams.audio:
time_base = container.streams.audio[0].time_base
# video_timebase is the default time_base
start_pts_sec, end_pts_sec, pts_unit = _video_opt._convert_to_sec(
start_pts, end_pts, pts_unit, time_base)
if container.streams.video: if container.streams.video:
video_frames = _read_from_stream( video_frames = _read_from_stream(
container, container,
start_pts, start_pts_sec,
end_pts, end_pts_sec,
pts_unit, pts_unit,
container.streams.video[0], container.streams.video[0],
{"video": 0}, {"video": 0},
...@@ -295,8 +303,8 @@ def read_video( ...@@ -295,8 +303,8 @@ def read_video(
if container.streams.audio: if container.streams.audio:
audio_frames = _read_from_stream( audio_frames = _read_from_stream(
container, container,
start_pts, start_pts_sec,
end_pts, end_pts_sec,
pts_unit, pts_unit,
container.streams.audio[0], container.streams.audio[0],
{"audio": 0}, {"audio": 0},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment