Commit 199a6ee2 authored by moto's avatar moto
Browse files

Fix issue with the missing video frame in StreamWriter (#2789)

Summary:
Addresses https://github.com/pytorch/audio/issues/2790.

Previously AVPacket objects had duration==0.

`av_interleaved_write_frame` function was inferring the duration of packets by
comparing them against the next ones but It could not infer the duration of
the last packet, as there is no subsequent frame, thus was omitting it from the final data.

This commit fixes it by explicitly setting packet duration = 1 (one frame)
only for video. (audio AVPacket contains multiple samples, so it's different.
To ensure the correctness for audio, the tests were added.)

Pull Request resolved: https://github.com/pytorch/audio/pull/2789

Reviewed By: xiaohui-zhang

Differential Revision: D40627439

Pulled By: mthrok

fbshipit-source-id: 4d0d827bff518c017b115445e03bdf0bf1e68320
parent 030646c0
...@@ -17,10 +17,6 @@ if is_ffmpeg_available(): ...@@ -17,10 +17,6 @@ if is_ffmpeg_available():
from torchaudio.io import StreamReader, StreamWriter from torchaudio.io import StreamReader, StreamWriter
# TODO:
# Get rid of StreamReader and use synthetic data.
def get_audio_chunk(fmt, sample_rate, num_channels): def get_audio_chunk(fmt, sample_rate, num_channels):
path = get_asset_path("nasa_13013.mp4") path = get_asset_path("nasa_13013.mp4")
s = StreamReader(path) s = StreamReader(path)
...@@ -255,3 +251,79 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -255,3 +251,79 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
chunk = chunk[:, [2, 1, 0], :, :] chunk = chunk[:, [2, 1, 0], :, :]
expected = rgb_to_yuv_ccir(chunk) expected = rgb_to_yuv_ccir(chunk)
self.assertEqual(expected, result, atol=1, rtol=0) self.assertEqual(expected, result, atol=1, rtol=0)
@nested_params([25, 30], [(78, 96), (240, 426), (360, 640)], ["yuv444p", "rgb24"])
def test_video_num_frames(self, framerate, resolution, format):
"""Saving video as MP4 properly keep all the frames"""
ext = "mp4"
filename = f"test.{ext}"
h, w = resolution
# Write data
dst = self.get_dst(filename)
s = torchaudio.io.StreamWriter(dst=dst, format=ext)
s.add_video_stream(frame_rate=framerate, height=h, width=w, format=format)
chunk = torch.stack([torch.full((3, h, w), i, dtype=torch.uint8) for i in torch.linspace(0, 255, 256)])
with s.open():
s.write_video_chunk(0, chunk)
if self.test_fileobj:
dst.flush()
# Load data
s = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
print(s.get_src_stream_info(0))
s.add_video_stream(-1)
s.process_all_packets()
(saved,) = s.pop_chunks()
assert saved.shape == chunk.shape
if format == "yuv444p":
# The following works if encoder_format is also yuv444p.
# Otherwise, the typical encoder format is yuv420p which incurs some data loss,
# and assertEqual fails.
#
# This is the case for libx264 encoder, but it's not always available.
# ffmpeg==4.2 from conda-forge (osx-arm64) comes with it but ffmpeg==5.1.2 does not.
# Since we do not have function to check the runtime availability of encoders,
# commenting it out for now.
# self.assertEqual(saved, chunk)
pass
@nested_params(
["wav", "mp3", "flac"],
[8000, 16000, 44100],
[1, 2],
)
def test_audio_num_frames(self, ext, sample_rate, num_channels):
""""""
filename = f"test.{ext}"
# Write data
dst = self.get_dst(filename)
s = torchaudio.io.StreamWriter(dst=dst, format=ext)
s.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels)
freq = 300
duration = 60
theta = torch.linspace(0, freq * 2 * 3.14 * duration, sample_rate * duration)
if num_channels == 1:
chunk = torch.sin(theta).unsqueeze(-1)
else:
chunk = torch.stack([torch.sin(theta), torch.cos(theta)], dim=-1)
with s.open():
s.write_audio_chunk(0, chunk)
if self.test_fileobj:
dst.flush()
# Load data
s = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
s.add_audio_stream(-1)
s.process_all_packets()
(saved,) = s.pop_chunks()
assert saved.shape == chunk.shape
if format in ["wav", "flac"]:
self.assertEqual(saved, chunk)
...@@ -644,6 +644,9 @@ void StreamWriter::process_frame( ...@@ -644,6 +644,9 @@ void StreamWriter::process_frame(
while (ret >= 0) { while (ret >= 0) {
ret = filter->get_frame(dst_frame); ret = filter->get_frame(dst_frame);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) { if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
if (ret == AVERROR_EOF) {
encode_frame(nullptr, c, st);
}
break; break;
} }
if (ret >= 0) { if (ret >= 0) {
...@@ -662,6 +665,23 @@ void StreamWriter::encode_frame( ...@@ -662,6 +665,23 @@ void StreamWriter::encode_frame(
while (ret >= 0) { while (ret >= 0) {
ret = avcodec_receive_packet(c, pkt); ret = avcodec_receive_packet(c, pkt);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) { if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
if (ret == AVERROR_EOF) {
// Note:
// av_interleaved_write_frame buffers the packets internally as needed
// to make sure the packets in the output file are properly interleaved
// in the order of increasing dts.
// https://ffmpeg.org/doxygen/3.4/group__lavf__encoding.html#ga37352ed2c63493c38219d935e71db6c1
// Passing nullptr will (forcefully) flush the queue, and this is
// necessary if users mal-configure the streams.
// Possible follow up: Add flush_buffer method?
// An alternative is to use `av_write_frame` functoin, but in that case
// client code is responsible for ordering packets, which makes it
// complicated to use StreamWriter
ret = av_interleaved_write_frame(pFormatContext, nullptr);
TORCH_CHECK(
ret >= 0, "Failed to flush packet (", av_err2string(ret), ").");
}
break; break;
} else { } else {
TORCH_CHECK( TORCH_CHECK(
...@@ -670,6 +690,14 @@ void StreamWriter::encode_frame( ...@@ -670,6 +690,14 @@ void StreamWriter::encode_frame(
av_err2string(ret), av_err2string(ret),
")."); ").");
} }
// https://github.com/pytorch/audio/issues/2790
// If this is not set, the last frame is not properly saved, as
// the encoder cannot figure out when the packet should finish.
if (pkt->duration == 0 && c->codec_type == AVMEDIA_TYPE_VIDEO) {
// 1 means that 1 frame (in codec time base, which is the frame rate)
// This has to be set before av_packet_rescale_ts bellow.
pkt->duration = 1;
}
av_packet_rescale_ts(pkt, c->time_base, st->time_base); av_packet_rescale_ts(pkt, c->time_base, st->time_base);
pkt->stream_index = st->index; pkt->stream_index = st->index;
...@@ -1027,7 +1055,6 @@ void StreamWriter::write_planar_video( ...@@ -1027,7 +1055,6 @@ void StreamWriter::write_planar_video(
} }
} }
// TODO: probably better to flush output streams in interweaving manner.
void StreamWriter::flush() { void StreamWriter::flush() {
for (auto& os : streams) { for (auto& os : streams) {
flush_stream(os); flush_stream(os);
...@@ -1037,8 +1064,9 @@ void StreamWriter::flush() { ...@@ -1037,8 +1064,9 @@ void StreamWriter::flush() {
void StreamWriter::flush_stream(OutputStream& os) { void StreamWriter::flush_stream(OutputStream& os) {
if (os.filter) { if (os.filter) {
process_frame(nullptr, os.filter, os.dst_frame, os.codec_ctx, os.stream); process_frame(nullptr, os.filter, os.dst_frame, os.codec_ctx, os.stream);
} else {
encode_frame(nullptr, os.codec_ctx, os.stream);
} }
encode_frame(nullptr, os.codec_ctx, os.stream);
} }
} // namespace ffmpeg } // namespace ffmpeg
} // namespace torchaudio } // namespace torchaudio
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment