Commit fbf05f28 authored by Moto Hira, committed by Facebook GitHub Bot
Browse files

Fix PTS regression (#3131)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3131

In https://github.com/pytorch/audio/pull/3122, the intermediate `num_frames` variable
is removed.

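PTS can still be incremented the same way, but #3122 applied the increment before the frame was passed to `process_frame`, so every frame was written one step ahead of its intended PTS.
This commit fixes it by moving the increment to after `process_frame`.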

Reviewed By: xiaohui-zhang

Differential Revision: D43712046

fbshipit-source-id: 2fe0082969296f4f3964e62e55b5325fcd45f4f9
parent 898db8c7
...@@ -4,6 +4,7 @@ import torchaudio ...@@ -4,6 +4,7 @@ import torchaudio
from parameterized import parameterized, parameterized_class from parameterized import parameterized, parameterized_class
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
get_asset_path, get_asset_path,
get_sinusoid,
is_ffmpeg_available, is_ffmpeg_available,
nested_params, nested_params,
rgb_to_yuv_ccir, rgb_to_yuv_ccir,
...@@ -352,3 +353,69 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -352,3 +353,69 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
# Load data # Load data
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename)) reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
assert reader.get_src_stream_info(0).frame_rate == frame_rate assert reader.get_src_stream_info(0).frame_rate == frame_rate
def test_video_pts_increment(self):
    """Each decoded video frame's PTS equals its index divided by the frame rate."""
    container = "mp4"
    path = f"test.{container}"
    num_frames = 256
    frame_rate = 5000 / 167
    width, height = 96, 128

    # Encode random frames into a single video stream.
    dst = self.get_dst(path)
    writer = torchaudio.io.StreamWriter(dst=dst, format=container)
    writer.add_video_stream(frame_rate=frame_rate, width=width, height=height)
    frames = torch.randint(256, (num_frames, 3, height, width), dtype=torch.uint8)
    with writer.open():
        writer.write_video_chunk(0, frames)

    if self.test_fileobj:
        dst.flush()

    # Read the file back and collect the PTS reported for every chunk.
    reader = torchaudio.io.StreamReader(src=self.get_temp_path(path))
    reader.add_video_stream(1)
    timestamps = []
    for (chunk,) in reader.stream():
        timestamps.append(chunk.pts)

    assert len(timestamps) == num_frames
    for index, pts in enumerate(timestamps):
        # The first frame is expected at 0; each later one advances by 1 / frame_rate.
        assert abs(pts - index / frame_rate) < 1e-10
def test_audio_pts_increment(self):
    """PTS values increment by the inverse of sample rate"""
    ext = "wav"
    filename = f"test.{ext}"
    sample_rate = 8000
    num_channels = 2

    # Write data
    dst = self.get_dst(filename)
    writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
    writer.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels)
    audio = get_sinusoid(sample_rate=sample_rate, n_channels=num_channels, channels_first=False)
    # The first dimension of the generated waveform is used as the frame count.
    num_frames = audio.size(0)
    with writer.open():
        writer.write_audio_chunk(0, audio)

    if self.test_fileobj:
        dst.flush()

    # Read the file back in quarter-second chunks.
    reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
    frames_per_chunk = sample_rate // 4
    reader.add_audio_stream(frames_per_chunk, -1)
    chunks = [chunk for (chunk,) in reader.stream()]
    expected_num_chunks = num_frames // frames_per_chunk
    assert len(chunks) == expected_num_chunks, f"Expected {expected_num_chunks} elements. Found {len(chunks)}"

    # Each chunk's PTS must equal the duration of all samples read before it.
    num_samples = 0
    for i, chunk in enumerate(chunks):
        expected_pts = num_samples / sample_rate
        assert abs(chunk.pts - expected_pts) < 1e-10, f"Chunk {i}: expected PTS {expected_pts}. Found {chunk.pts}"
        num_samples += chunk.size(0)
...@@ -135,9 +135,9 @@ void AudioOutputStream::write_chunk(const torch::Tensor& waveform) { ...@@ -135,9 +135,9 @@ void AudioOutputStream::write_chunk(const torch::Tensor& waveform) {
"Internal Error: frame is not writable."); "Internal Error: frame is not writable.");
memcpy(src_frame->data[0], chunk.data_ptr(), byte_size); memcpy(src_frame->data[0], chunk.data_ptr(), byte_size);
src_frame->pts += av_rescale_q(num_frames, time_base, codec_ctx->time_base);
src_frame->nb_samples = num_frames; src_frame->nb_samples = num_frames;
process_frame(src_frame); process_frame(src_frame);
src_frame->pts += av_rescale_q(num_frames, time_base, codec_ctx->time_base);
} }
} }
......
...@@ -52,8 +52,8 @@ VideoOutputStream::VideoOutputStream( ...@@ -52,8 +52,8 @@ VideoOutputStream::VideoOutputStream(
void VideoOutputStream::write_chunk(const torch::Tensor& frames) { void VideoOutputStream::write_chunk(const torch::Tensor& frames) {
for (const auto& frame : converter.convert(frames)) { for (const auto& frame : converter.convert(frames)) {
frame->pts += 1;
process_frame(frame); process_frame(frame);
frame->pts += 1;
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment