Unverified Commit f5843099 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Fixed audio-video synchronisation problem in read_video() when using `pts` as unit (#3791)

* Fixed audio-video synchronisation problem in read_video() when using  as unit

* Addressed review comments

* Added unit test
parent 154283b1
......@@ -1238,6 +1238,45 @@ class TestVideoReader(unittest.TestCase):
)
# FUTURE: check value of video / audio frames
def test_audio_video_sync(self):
"""Test if audio/video are synchronised with pyav output."""
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
container = av.open(full_path)
if not container.streams.audio:
# Skip if no audio stream
continue
start_pts_val, cutoff = 0, 1
if container.streams.video:
video = container.streams.video[0]
arr = []
for index, frame in enumerate(container.decode(video)):
if index == cutoff:
start_pts_val = frame.pts
if index >= cutoff:
arr.append(frame.to_rgb().to_ndarray())
visual, _, info = io.read_video(full_path, start_pts=start_pts_val, pts_unit='pts')
self.assertAlmostEqual(
config.video_fps, info['video_fps'], delta=0.0001
)
arr = torch.Tensor(arr)
if arr.shape == visual.shape:
self.assertGreaterEqual(
torch.mean(torch.isclose(visual.float(), arr, atol=1e-5).float()), 0.99)
container = av.open(full_path)
if container.streams.audio:
audio = container.streams.audio[0]
arr = []
for index, frame in enumerate(container.decode(audio)):
if index >= cutoff:
arr.append(frame.to_ndarray())
_, audio, _ = io.read_video(full_path, start_pts=start_pts_val, pts_unit='pts')
arr = torch.as_tensor(np.concatenate(arr, axis=1))
if arr.shape == audio.shape:
self.assertGreaterEqual(
torch.mean(torch.isclose(audio.float(), arr).float()), 0.99)
if __name__ == "__main__":
unittest.main()
......@@ -471,6 +471,14 @@ def _probe_video_from_memory(video_data):
return info
def _convert_to_sec(start_pts, end_pts, pts_unit, time_base):
if pts_unit == 'pts':
start_pts = float(start_pts * time_base)
end_pts = float(end_pts * time_base)
pts_unit = 'sec'
return start_pts, end_pts, pts_unit
def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
if end_pts is None:
end_pts = float("inf")
......@@ -485,32 +493,43 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
has_video = info.has_video
has_audio = info.has_audio
video_pts_range = (0, -1)
video_timebase = default_timebase
audio_pts_range = (0, -1)
audio_timebase = default_timebase
time_base = default_timebase
if has_video:
video_timebase = Fraction(
info.video_timebase.numerator, info.video_timebase.denominator
)
time_base = video_timebase
if has_audio:
audio_timebase = Fraction(
info.audio_timebase.numerator, info.audio_timebase.denominator
)
time_base = time_base if time_base else audio_timebase
# video_timebase is the default time_base
start_pts_sec, end_pts_sec, pts_unit = _convert_to_sec(
start_pts, end_pts, pts_unit, time_base)
def get_pts(time_base):
start_offset = start_pts
end_offset = end_pts
start_offset = start_pts_sec
end_offset = end_pts_sec
if pts_unit == "sec":
start_offset = int(math.floor(start_pts * (1 / time_base)))
start_offset = int(math.floor(start_pts_sec * (1 / time_base)))
if end_offset != float("inf"):
end_offset = int(math.ceil(end_pts * (1 / time_base)))
end_offset = int(math.ceil(end_pts_sec * (1 / time_base)))
if end_offset == float("inf"):
end_offset = -1
return start_offset, end_offset
video_pts_range = (0, -1)
video_timebase = default_timebase
if has_video:
video_timebase = Fraction(
info.video_timebase.numerator, info.video_timebase.denominator
)
video_pts_range = get_pts(video_timebase)
audio_pts_range = (0, -1)
audio_timebase = default_timebase
if has_audio:
audio_timebase = Fraction(
info.audio_timebase.numerator, info.audio_timebase.denominator
)
audio_pts_range = get_pts(audio_timebase)
vframes, aframes, info = _read_video_from_file(
......
......@@ -278,11 +278,19 @@ def read_video(
try:
with av.open(filename, metadata_errors="ignore") as container:
time_base = _video_opt.default_timebase
if container.streams.video:
time_base = container.streams.video[0].time_base
elif container.streams.audio:
time_base = container.streams.audio[0].time_base
# video_timebase is the default time_base
start_pts_sec, end_pts_sec, pts_unit = _video_opt._convert_to_sec(
start_pts, end_pts, pts_unit, time_base)
if container.streams.video:
video_frames = _read_from_stream(
container,
start_pts,
end_pts,
start_pts_sec,
end_pts_sec,
pts_unit,
container.streams.video[0],
{"video": 0},
......@@ -295,8 +303,8 @@ def read_video(
if container.streams.audio:
audio_frames = _read_from_stream(
container,
start_pts,
end_pts,
start_pts_sec,
end_pts_sec,
pts_unit,
container.streams.audio[0],
{"audio": 0},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment